Clean-up composite configs (#34603)

* remove manual assignment of tie-word-embeddings

* remove another unused attribute

* fix tests

* fix tests

* remove unnecessary overwrites

* fix

* decoder=True

* clean pix2struct

* run-all

* forgot `_tied_weights_keys` when adding Emu3

* also Aria + fix-copies

* and clean aria
Raushan Turganbay 2025-01-15 10:04:07 +01:00 committed by GitHub
parent c61fcde910
commit 09d5f76274
33 changed files with 68 additions and 219 deletions

View File

@ -249,9 +249,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def _update_causal_mask(
self,
attention_mask,

View File

@ -1862,7 +1862,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
weights instead.
"""
if getattr(self.config, "tie_word_embeddings", True):
if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True):
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@ -2104,7 +2104,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_num_tokens = new_embeddings.weight.shape[0]
# if word embeddings are not tied, make sure that lm head is resized as well
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
if (
self.get_output_embeddings() is not None
and not self.config.get_text_config(decoder=True).tie_word_embeddings
):
old_lm_head = self.get_output_embeddings()
if isinstance(old_lm_head, torch.nn.Embedding):
new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
@ -4604,7 +4607,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_loaded_keys = loaded_keys
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
# If we're about to tie the output embeds to the input embeds we don't need to init them
if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
if (
hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings")
and model.config.get_text_config(decoder=True).tie_word_embeddings
):
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None:
# Still need to initialize if there is a bias term since biases are not tied.
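
The three hunks above all switch the same lookup: instead of reading a mirrored top-level `tie_word_embeddings`, the base class now asks the config for its decoder text config and reads the flag there. A minimal runnable sketch of that pattern (toy objects, not the transformers implementation):

from types import SimpleNamespace

# Toy stand-ins: a composite config keeps `tie_word_embeddings` only on its
# nested text config, and `get_text_config(decoder=True)` returns that
# sub-config (a plain text config would return itself).
text_cfg = SimpleNamespace(tie_word_embeddings=False)
composite_cfg = SimpleNamespace(
    text_config=text_cfg,
    get_text_config=lambda decoder=False: text_cfg,
)

# The exact expression used in `tie_weights` above:
should_tie = getattr(composite_cfg.get_text_config(decoder=True), "tie_word_embeddings", True)
assert should_tie is False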

View File

@ -1360,6 +1360,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)
@ -1406,9 +1407,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def get_image_features(
self,
pixel_values: torch.FloatTensor,

View File

@ -1337,6 +1337,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)
@ -1383,9 +1384,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def get_image_features(
self,
pixel_values: torch.FloatTensor,

View File

@ -302,9 +302,6 @@ class Blip2Config(PretrainedConfig):
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder
self.num_query_tokens = num_query_tokens
self.image_text_hidden_size = image_text_hidden_size
self.image_token_index = image_token_index

View File

@ -1433,7 +1433,7 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation)
text_model = CLIPTextModel._from_config(config)
self.text_model = text_model.text_model
self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
@ -1514,7 +1514,7 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation)
vision_model = CLIPVisionModel._from_config(config)
self.vision_model = vision_model.vision_model
self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
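
Here (and in GroundingDino, Fuyu, and Moshi below) the explicit `attn_implementation=config._attn_implementation` argument is dropped. A hedged sketch of why the kwarg became redundant, assuming `_from_config` now falls back to the value the config already carries; the toy names are illustrative, not the real API:

class ToyConfig:
    def __init__(self, attn_implementation="sdpa"):
        self._attn_implementation = attn_implementation

class ToySubModel:
    attn_backend = None

    @classmethod
    def _from_config(cls, config, **kwargs):
        # Fall back to the backend recorded on the config when the caller
        # does not pass one explicitly -- the case the diff above relies on.
        model = cls()
        model.attn_backend = kwargs.pop("attn_implementation", None) or config._attn_implementation
        return model

model = ToySubModel._from_config(ToyConfig("flash_attention_2"))
assert model.attn_backend == "flash_attention_2"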

View File

@ -1791,6 +1791,8 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)

View File

@ -1103,6 +1103,8 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)

View File

@ -151,9 +151,9 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.vision_embed_tokens = nn.Linear(
config.patch_size * config.patch_size * config.num_channels, config.hidden_size
@ -181,9 +181,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def gather_continuous_embeddings(
self,
word_embeddings: torch.Tensor,
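
Fuyu now derives `_tied_weights_keys` from its language backbone at init time, the same pattern the Llava family gets below, while Aria and Emu3 above pin the one known key statically. A minimal sketch of the dynamic form, assuming a backbone that declares its own tied keys the way transformers causal LMs do:

class ToyBackbone:
    _tied_weights_keys = ["lm_head.weight"]

class ToyWrapper:
    _tied_weights_keys = None

    def __init__(self):
        self.language_model = ToyBackbone()
        # Re-root the backbone's keys under the attribute the wrapper
        # stores it on, exactly as in the diff above.
        if self.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [
                f"language_model.{k}" for k in self.language_model._tied_weights_keys
            ]

assert ToyWrapper()._tied_weights_keys == ["language_model.lm_head.weight"]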

View File

@ -2104,9 +2104,7 @@ class GroundingDinoModel(GroundingDinoPreTrainedModel):
)
# Create text backbone
self.text_backbone = AutoModel.from_config(
config.text_config, add_pooling_layer=False, attn_implementation=config._attn_implementation
)
self.text_backbone = AutoModel.from_config(config.text_config, add_pooling_layer=False)
self.text_projection = nn.Linear(config.text_config.hidden_size, config.d_model)
if config.embedding_init_target or not config.two_stage:

View File

@ -1285,13 +1285,6 @@ class Idefics2Model(Idefics2PreTrainedModel):
def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.text_model.resize_token_embeddings(
new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
)
self.config.text_config.vocab_size = model_embeds.num_embeddings
return model_embeds
def inputs_merger(
self,
input_ids: torch.LongTensor,
@ -1515,32 +1508,6 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
# model_embeds = self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds
# Update base model and current model config
# Ignore copy
self.config.text_config.vocab_size = model_embeds.weight.shape[0]
self.vocab_size = self.config.text_config.vocab_size
# Tie weights again if needed
self.tie_weights()
return model_embeds
def tie_weights(self):
"""
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
"""
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
if getattr(self.config, "tie_word_embeddings", True):
output_embeddings.weight = input_embeddings.weight
@add_start_docstrings_to_model_forward(IDEFICS2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@ -1094,17 +1094,6 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.tie_weights
def tie_weights(self):
"""
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
"""
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
if getattr(self.config, "tie_word_embeddings", True):
output_embeddings.weight = input_embeddings.weight
@add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@ -302,9 +302,6 @@ class InstructBlipConfig(PretrainedConfig):
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder
self.num_query_tokens = num_query_tokens
self.image_token_index = image_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size

View File

@ -308,9 +308,6 @@ class InstructBlipVideoConfig(PretrainedConfig):
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder
self.num_query_tokens = num_query_tokens
self.video_token_index = video_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size

View File

@ -137,9 +137,6 @@ class InstructBlipVideoConfig(PretrainedConfig):
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self.is_encoder_decoder = self.text_config.is_encoder_decoder
self.num_query_tokens = num_query_tokens
self.video_token_index = video_token_index
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size

View File

@ -242,6 +242,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
@ -264,16 +268,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def get_image_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
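
With `_tied_weights_keys` populated and the tie flag resolved through `get_text_config(decoder=True)`, the inherited `tie_weights` and `resize_token_embeddings` cover what these overrides did by hand. A sketch of the vocab-size bookkeeping, under the assumption that the generic resize path writes back to the nested text config (which is what makes the deleted override redundant):

class ToyTextConfig:
    vocab_size = 32000

class ToyCompositeConfig:
    def __init__(self):
        self.text_config = ToyTextConfig()

    def get_text_config(self, decoder=False):
        return self.text_config

def generic_resize(config, new_num_tokens):
    # Stand-in for the bookkeeping assumed to live in the base
    # `resize_token_embeddings`; not the actual implementation.
    config.get_text_config().vocab_size = new_num_tokens

cfg = ToyCompositeConfig()
generic_resize(cfg, 32064)
assert cfg.text_config.vocab_size == 32064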

View File

@ -357,6 +357,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
@ -395,18 +398,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
def tie_weights(self):
return self.language_model.tie_weights()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_image_features(
self,
image_features,

View File

@ -397,6 +397,9 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.vision_resampler = LlavaNextVideoPooler(config)
@ -430,16 +433,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_image_features(
self,
image_features,

View File

@ -374,6 +374,9 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.post_init()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
@ -400,10 +403,6 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.tie_weights
def tie_weights(self):
return self.language_model.tie_weights()
def pack_image_features(self, image_features, image_sizes, image_newline=None, vision_aspect_ratio="anyres_max_9"):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

View File

@ -1986,6 +1986,9 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
self.vision_model = MllamaVisionModel._from_config(config.vision_config)
self.language_model = MllamaForCausalLM._from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
@ -2011,9 +2014,6 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig")
def forward(

View File

@ -1914,12 +1914,9 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
self.embed_tokens = nn.ModuleList(
[nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)]
)
self.audio_encoder = AutoModel.from_config(
config.audio_encoder_config, attn_implementation=config._attn_implementation
)
self.audio_encoder = AutoModel.from_config(config.audio_encoder_config)
self.decoder = MoshiForCausalLM(config)
config.depth_decoder_config._attn_implementation_internal = config._attn_implementation
self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)
self.num_codebooks = config.num_codebooks

View File

@ -335,10 +335,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
def tie_weights(self):
return self.language_model.tie_weights()
def _update_causal_mask(
self,
attention_mask,

View File

@ -14,9 +14,6 @@
# limitations under the License.
"""Pix2Struct model configuration"""
import os
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@ -147,26 +144,6 @@ class Pix2StructTextConfig(PretrainedConfig):
**kwargs,
)
@classmethod
def from_pretrained(
cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs)
# get the text config dict if we are loading from Pix2StructConfig
if config_dict.get("model_type") == "pix2struct":
config_dict = config_dict["text_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class Pix2StructVisionConfig(PretrainedConfig):
r"""
@ -266,26 +243,6 @@ class Pix2StructVisionConfig(PretrainedConfig):
self.relative_attention_max_distance = relative_attention_max_distance
self.d_kv = d_kv
@classmethod
def from_pretrained(
cls, pretrainehidden_size_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrainehidden_size_name_or_path, **kwargs)
# get the vision config dict if we are loading from Pix2StructConfig
if config_dict.get("model_type") == "pix2struct":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class Pix2StructConfig(PretrainedConfig):
r"""

View File

@ -1733,14 +1733,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMi
def set_output_embeddings(self, new_embeddings):
self.decoder.set_output_embeddings(new_embeddings)
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)
# update vocab size
self.config.text_config.vocab_size = new_num_tokens
return model_embeds
def get_decoder(self):
return self.decoder

View File

@ -856,6 +856,9 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
@ -894,18 +897,6 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
def tie_weights(self):
return self.language_model.tie_weights()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_audio_features(
self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels
):

View File

@ -261,6 +261,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
return self.decoder.get_input_embeddings()
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()

View File

@ -245,6 +245,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
self.multi_modal_projector = VideoLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
@ -266,17 +269,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _merge_input_ids_with_visual_features(
self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1
):

View File

@ -242,6 +242,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
self.multi_modal_projector = VipLlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
@ -264,16 +268,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
# Ignore copy
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]):
"""

View File

@ -237,6 +237,9 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
return self.decoder.get_input_embeddings()
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
@ -659,12 +662,6 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
" respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past_key_values, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past_key_values, beam_idx)
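
Rather than raising on `resize_token_embeddings`, the wrapper now exposes the decoder's embedding accessors (the same addition made to SpeechEncoderDecoderModel above), which is all the generic embedding machinery needs to reach the decoder. A toy sketch of the delegation:

class ToyDecoder:
    def __init__(self):
        self.tok_embed = object()

    def get_input_embeddings(self):
        return self.tok_embed

class ToyEncoderDecoder:
    def __init__(self):
        self.decoder = ToyDecoder()

    # Mirrors the accessor added in the diff above.
    def get_input_embeddings(self):
        return self.decoder.get_input_embeddings()

wrapper = ToyEncoderDecoder()
assert wrapper.get_input_embeddings() is wrapper.decoder.tok_embed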

View File

@ -1123,6 +1123,13 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
def setUp(self):
self.model_tester = Blip2ModelTester(self)
common_properties = ["image_token_index", "num_query_tokens", "image_text_hidden_size"]
self.config_tester = ConfigTester(
self, config_class=Blip2Config, has_text_modality=False, common_properties=common_properties
)
def test_config(self):
self.config_tester.run_common_tests()
def test_for_conditional_generation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
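
The test setups here and in the two files below pass `common_properties` so the shared config tests exercise the attributes the composite config still owns, now that text-model flags no longer live at the top level. A hedged stand-in for the kind of check this drives (the real ConfigTester lives in the transformers test suite):

def check_common_properties(config, common_properties):
    # Each listed attribute must exist on the composite config itself.
    for prop in common_properties:
        assert hasattr(config, prop), f"missing common property: {prop}"

class ToyBlip2Config:
    image_token_index = 50265
    num_query_tokens = 32
    image_text_hidden_size = 256

check_common_properties(
    ToyBlip2Config(),
    ["image_token_index", "num_query_tokens", "image_text_hidden_size"],
)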

View File

@ -158,7 +158,10 @@ class InstructBlipVisionModelTest(ModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = InstructBlipVisionModelTester(self)
self.config_tester = ConfigTester(
self, config_class=InstructBlipVisionConfig, has_text_modality=False, hidden_size=37
self,
config_class=InstructBlipConfig,
has_text_modality=False,
common_properties=["num_query_tokens", "image_token_index"],
)
def test_config(self):

View File

@ -163,8 +163,9 @@ class InstructBlipVideoVisionModelTest(ModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = InstructBlipVideoVisionModelTester(self)
common_properties = ["num_query_tokens", "video_token_index"]
self.config_tester = ConfigTester(
self, config_class=InstructBlipVideoVisionConfig, has_text_modality=False, hidden_size=37
self, config_class=InstructBlipVideoConfig, has_text_modality=False, common_properties=common_properties
)
def test_config(self):

View File

@ -2283,7 +2283,7 @@ class ModelTesterMixin:
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = True
config.get_text_config().tie_word_embeddings = True
for model_class in self.all_model_classes:
model_tied = model_class(config)
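
The one-line change makes the test work for both plain and composite configs: `get_text_config()` returns the config itself for text-only models and the nested text config otherwise, so flipping the flag there reaches whichever object `tie_weights` will actually consult. Toy sketch:

class PlainConfig:
    tie_word_embeddings = False

    def get_text_config(self, decoder=False):
        return self

class CompositeConfig:
    def __init__(self):
        self.text_config = PlainConfig()

    def get_text_config(self, decoder=False):
        return self.text_config

plain, composite = PlainConfig(), CompositeConfig()
for cfg in (plain, composite):
    cfg.get_text_config().tie_word_embeddings = True

assert plain.tie_word_embeddings and composite.text_config.tie_word_embeddings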