diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py index fed90c86b4a..f3faa60ca3d 100644 --- a/src/transformers/models/aria/configuration_aria.py +++ b/src/transformers/models/aria/configuration_aria.py @@ -258,6 +258,9 @@ class AriaConfig(PretrainedConfig): """ model_type = "aria" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig} def __init__( diff --git a/src/transformers/models/aria/convert_aria_weights_to_hf.py b/src/transformers/models/aria/convert_aria_weights_to_hf.py index dcc9e4d1397..a95f3cda834 100644 --- a/src/transformers/models/aria/convert_aria_weights_to_hf.py +++ b/src/transformers/models/aria/convert_aria_weights_to_hf.py @@ -106,7 +106,7 @@ def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, ol config.vision_config.hidden_size = 1152 config.vision_config.attention_heads = 16 config.pad_token_id = 2 - config.image_token_index = 9 + config.image_token_id = 9 config.intermediate_size = config.moe_intermediate_size config.auto_map = { "AutoConfig": "modeling_aria.AriaConfig", diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 7f88875f3e7..4dc9df7a515 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1507,11 +1507,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): if pixel_values is not None and inputs_embeds.shape[1] != 1: if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] else: - image_embeds = input_ids == self.config.image_token_index + image_embeds = input_ids == self.config.image_token_id special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) image_features = self.get_image_features( diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index b087c215366..add5bdc16b7 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -266,6 +266,9 @@ class AriaConfig(PretrainedConfig): """ model_type = "aria" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig} def __init__( @@ -1546,11 +1549,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): if pixel_values is not None and inputs_embeds.shape[1] != 1: if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] else: - image_embeds = input_ids == self.config.image_token_index + image_embeds = input_ids == self.config.image_token_id special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0) image_features = self.get_image_features( diff --git 
a/src/transformers/models/aya_vision/configuration_aya_vision.py b/src/transformers/models/aya_vision/configuration_aya_vision.py index 574a5755abd..ad7fdfd319d 100644 --- a/src/transformers/models/aya_vision/configuration_aya_vision.py +++ b/src/transformers/models/aya_vision/configuration_aya_vision.py @@ -52,6 +52,9 @@ class AyaVisionConfig(PretrainedConfig): """ model_type = "aya_vision" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} def __init__( diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 13e3dfdb43b..45c2ab66e3e 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -444,10 +444,10 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi image_sizes=image_sizes, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index fa186633c43..db55e39ab73 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -273,6 +273,9 @@ class Blip2Config(PretrainedConfig): ```""" model_type = "blip-2" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "qformer_config": Blip2QFormerConfig, "vision_config": Blip2VisionConfig} def __init__( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 1ee48b00b81..d6ec49505bc 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2283,10 +2283,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds + # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "image_token_id", None) is not None: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) else: @@ -2406,8 +2406,8 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): if input_ids is None: 
start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "image_token_index", None) is not None: - start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens + if getattr(self.config, "image_token_id", None) is not None: + start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) input_ids = input_ids.repeat(batch_size, 1) @@ -2415,10 +2415,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds + # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "image_token_id", None) is not None: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten() else: logger.warning_once( diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index 068e9a060be..6c0e4b9d809 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -285,6 +285,11 @@ class Gemma3Config(PretrainedConfig): ```""" model_type = "gemma3" + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } sub_configs = { "text_config": Gemma3TextConfig, "vision_config": SiglipVisionConfig, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 170e3d952f3..316130ce9d1 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -1274,8 +1274,8 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): is_training = token_type_ids is not None and labels is not None # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() llm_input_ids[special_image_mask] = 0 else: @@ -1296,10 +1296,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != 
image_features.numel(): diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 7c95f63b0e0..90e6a4be2ff 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -257,6 +257,11 @@ class Gemma3Config(PretrainedConfig): ```""" model_type = "gemma3" + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } sub_configs = { "text_config": Gemma3TextConfig, "vision_config": SiglipVisionConfig, @@ -922,8 +927,8 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): is_training = token_type_ids is not None and labels is not None # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() llm_input_ids[special_image_mask] = 0 else: @@ -944,10 +949,10 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): diff --git a/src/transformers/models/got_ocr2/configuration_got_ocr2.py b/src/transformers/models/got_ocr2/configuration_got_ocr2.py index fb9a1fb6888..3e7dd7561f6 100644 --- a/src/transformers/models/got_ocr2/configuration_got_ocr2.py +++ b/src/transformers/models/got_ocr2/configuration_got_ocr2.py @@ -153,6 +153,9 @@ class GotOcr2Config(PretrainedConfig): ```""" model_type = "got_ocr2" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig} def __init__( diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 17db99b2a3e..d3a8a637ede 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -813,13 +813,13 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): if pixel_values is not None: image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = 
special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index e3daafd81cc..aec8c5e1749 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -176,6 +176,9 @@ class GotOcr2Config(PretrainedConfig): ```""" model_type = "got_ocr2" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig} def __init__( @@ -477,13 +480,13 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration): if pixel_values is not None: image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/granite_speech/configuration_granite_speech.py b/src/transformers/models/granite_speech/configuration_granite_speech.py index e1355db41fd..b3e7e388a1a 100644 --- a/src/transformers/models/granite_speech/configuration_granite_speech.py +++ b/src/transformers/models/granite_speech/configuration_granite_speech.py @@ -147,6 +147,9 @@ class GraniteSpeechConfig(PretrainedConfig): ```""" model_type = "granite_speech" + attribute_map = { + "audio_token_id": "audio_token_index", + } sub_configs = { "text_config": AutoConfig, "encoder_config": GraniteSpeechEncoderConfig, diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 6ae4e6c35ff..55842e7426e 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -515,7 +515,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera # Get the base embeddings; set all audio tokens to 0 index # to avoid out of vocabulary issues with the LLM embedding. # Audio features will be masked into is_audio_idx indices later. - is_audio_idx = input_ids == self.config.audio_token_index + is_audio_idx = input_ids == self.config.audio_token_id llm_input_ids = input_ids.clone() llm_input_ids[is_audio_idx] = 0 inputs_embeds = self.get_input_embeddings()(llm_input_ids) @@ -624,7 +624,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera input_features_mask (`torch.Tensor`, *optional*, defaults to `None`) Mask to be applied to audio features prior to scattering into the language embeddings. 
""" - is_audio_index = input_ids == self.config.audio_token_index + is_audio_index = input_ids == self.config.audio_token_id llm_input_ids = torch.where(is_audio_index, 0, input_ids) inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size] diff --git a/src/transformers/models/instructblip/configuration_instructblip.py b/src/transformers/models/instructblip/configuration_instructblip.py index 328d64761a5..04c32d552da 100644 --- a/src/transformers/models/instructblip/configuration_instructblip.py +++ b/src/transformers/models/instructblip/configuration_instructblip.py @@ -268,6 +268,9 @@ class InstructBlipConfig(PretrainedConfig): ```""" model_type = "instructblip" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = { "text_config": AutoConfig, "qformer_config": InstructBlipQFormerConfig, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index a304353cc41..cdfb59b5804 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1467,10 +1467,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds + # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "image_token_id", None) is not None: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten() else: logger.warning_once( @@ -1599,8 +1599,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "image_token_index", None) is not None: - start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens + if getattr(self.config, "image_token_id", None) is not None: + start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) input_ids = input_ids.repeat(batch_size, 1) @@ -1609,10 +1609,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "image_token_index" then the input is expanded to account for image embeds + # if the model already has "image_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "image_token_index", None) is not None: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "image_token_id", None) is not None: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) else: logger.warning_once( 
diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index d4cdf659763..d7611395968 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -274,6 +274,9 @@ class InstructBlipVideoConfig(PretrainedConfig): ```""" model_type = "instructblipvideo" + attribute_map = { + "video_token_id": "video_token_index", + } sub_configs = { "text_config": AutoConfig, "qformer_config": InstructBlipVideoQFormerConfig, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index d2a6c7b6f12..e9d9e4938a9 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1495,10 +1495,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds + # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "video_token_id", None) is not None: + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) else: logger.warning_once( @@ -1635,8 +1635,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "video_token_index", None) is not None: - start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens + if getattr(self.config, "video_token_id", None) is not None: + start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) input_ids = input_ids.repeat(batch_size, 1) @@ -1645,10 +1645,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds + # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "video_token_id", None) is not None: + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) else: logger.warning_once( diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index 
ed2364edce2..212050877a4 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -106,6 +106,9 @@ class InstructBlipVideoConfig(PretrainedConfig): ```""" model_type = "instructblipvideo" + attribute_map = { + "video_token_id": "video_token_index", + } sub_configs = { "text_config": AutoConfig, "qformer_config": InstructBlipVideoQFormerConfig, @@ -315,10 +318,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera if attention_mask is None: attention_mask = torch.ones_like(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds + # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "video_token_id", None) is not None: + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) else: logger.warning_once( @@ -455,8 +458,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera if input_ids is None: start_tokens = [self.config.text_config.bos_token_id] - if getattr(self.config, "video_token_index", None) is not None: - start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens + if getattr(self.config, "video_token_id", None) is not None: + start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) input_ids = input_ids.repeat(batch_size, 1) @@ -465,10 +468,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera inputs_embeds = self.get_input_embeddings()(input_ids) - # if the model already has "video_token_index" then the input is expanded to account for image embeds + # if the model already has "video_token_id" then the input is expanded to account for image embeds # otherwise we expand manually by concatenating - if getattr(self.config, "video_token_index", None) is not None: - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + if getattr(self.config, "video_token_id", None) is not None: + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device) else: logger.warning_once( diff --git a/src/transformers/models/janus/configuration_janus.py b/src/transformers/models/janus/configuration_janus.py index de727ab6d07..d90f64b3831 100644 --- a/src/transformers/models/janus/configuration_janus.py +++ b/src/transformers/models/janus/configuration_janus.py @@ -230,6 +230,8 @@ class JanusConfig(PretrainedConfig): The config object or dictionary of the vision backbone. vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`): The config object or dictionary of the VQVAE backbone. + image_token_id (`int`, *optional*, defaults to 100581): + Token index of a placeholder image token. 
Example: @@ -262,7 +264,14 @@ class JanusConfig(PretrainedConfig): "vq_config": JanusVQVAEConfig, } - def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs): + def __init__( + self, + text_config=None, + vision_config=None, + vq_config=None, + image_token_id=100581, + **kwargs, + ): if isinstance(text_config, dict): text_config["model_type"] = text_config.get("model_type", "llama") self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) @@ -307,7 +316,7 @@ class JanusConfig(PretrainedConfig): # This dimension is required when decoding discrete image tokens to continuous input. self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size # The default is only the index for the 1B model, 7B uses a different one - self.image_token_index = kwargs.get("image_token_index", 100581) + self.image_token_id = image_token_id super().__init__(**kwargs) diff --git a/src/transformers/models/janus/convert_janus_weights_to_hf.py b/src/transformers/models/janus/convert_janus_weights_to_hf.py index 32e16780bbe..dc47f4ee8e5 100644 --- a/src/transformers/models/janus/convert_janus_weights_to_hf.py +++ b/src/transformers/models/janus/convert_janus_weights_to_hf.py @@ -390,7 +390,7 @@ def convert_model( text_config=text_config, vision_config=vision_config, vq_config=vq_config, - image_token_index=tokenizer.vocab.get(""), + image_token_id=tokenizer.vocab.get(""), ) # Save the config diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index a4c7937bc45..2a0af557f90 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1281,7 +1281,7 @@ class JanusModel(JanusPreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values) - image_attention_mask = input_ids == self.config.image_token_index + image_attention_mask = input_ids == self.config.image_token_id embed_dim = inputs_embeds.shape[-1] image_features = image_embeds.reshape(-1, embed_dim) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 03e3a05a27a..3a0efff5ae2 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -306,6 +306,8 @@ class JanusConfig(PretrainedConfig): The config object or dictionary of the vision backbone. vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`): The config object or dictionary of the VQVAE backbone. + image_token_id (`int`, *optional*, defaults to 100581): + Token index of a placeholder image token. Example: @@ -338,7 +340,14 @@ class JanusConfig(PretrainedConfig): "vq_config": JanusVQVAEConfig, } - def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs): + def __init__( + self, + text_config=None, + vision_config=None, + vq_config=None, + image_token_id=100581, + **kwargs, + ): if isinstance(text_config, dict): text_config["model_type"] = text_config.get("model_type", "llama") self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) @@ -383,7 +392,7 @@ class JanusConfig(PretrainedConfig): # This dimension is required when decoding discrete image tokens to continuous input. 
self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size # The default is only the index for the 1B model, 7B uses a different one - self.image_token_index = kwargs.get("image_token_index", 100581) + self.image_token_id = image_token_id super().__init__(**kwargs) @@ -1081,7 +1090,7 @@ class JanusModel(JanusPreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values) - image_attention_mask = input_ids == self.config.image_token_index + image_attention_mask = input_ids == self.config.image_token_id embed_dim = inputs_embeds.shape[-1] image_features = image_embeds.reshape(-1, embed_dim) diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py index c4cef4d4ab5..e296d4e3c68 100644 --- a/src/transformers/models/llama4/configuration_llama4.py +++ b/src/transformers/models/llama4/configuration_llama4.py @@ -395,6 +395,11 @@ class Llama4Config(PretrainedConfig): ```""" model_type = "llama4" + attribute_map = { + "image_token_id": "image_token_index", + "boi_token_id": "boi_token_index", + "eoi_token_id": "eoi_token_index", + } sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig} base_model_tp_plan = { "multi_modal_projector.linear_1": "colwise_rep", diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 021bae7f62c..985c0c0d2f4 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -1745,7 +1745,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.multi_modal_projector(vision_flat) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) final_mask = special_image_mask.to(inputs_embeds.device) inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index f476591b2eb..7b6ad8d0e0e 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -75,6 +75,9 @@ class LlavaConfig(PretrainedConfig): ```""" model_type = "llava" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} def __init__( diff --git a/src/transformers/models/llava/convert_llava_weights_to_hf.py b/src/transformers/models/llava/convert_llava_weights_to_hf.py index 3582b9772c9..dafbf8bf2f6 100644 --- a/src/transformers/models/llava/convert_llava_weights_to_hf.py +++ b/src/transformers/models/llava/convert_llava_weights_to_hf.py @@ -129,13 +129,13 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o # llms-lab interleeave models do not use any selection startegy except for last hidden state if "Qwen" in text_model_id: - config.image_token_index = 151646 + config.image_token_id = 151646 if "siglip" in vision_model_id: config.vision_feature_select_strategy = "full" config.vision_feature_layer = -1 else: config.pad_token_id = 32001 - config.image_token_index = 32000 + config.image_token_id = 32000 with torch.device("meta"): model = LlavaForConditionalGeneration(config) diff --git 
a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index ba5277d4ff3..bc78b571d95 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -405,10 +405,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): image_sizes=image_sizes, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/llava_next/configuration_llava_next.py b/src/transformers/models/llava_next/configuration_llava_next.py index 3836dbf71cd..8e2ba1db75d 100644 --- a/src/transformers/models/llava_next/configuration_llava_next.py +++ b/src/transformers/models/llava_next/configuration_llava_next.py @@ -80,6 +80,9 @@ class LlavaNextConfig(PretrainedConfig): ```""" model_type = "llava_next" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} def __init__( diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 85f21d4a5be..5d6e098d03c 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -102,25 +102,25 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): if model_id == "liuhaotian/llava-v1.6-mistral-7b": text_model_id = "mistralai/Mistral-7B-Instruct-v0.2" - image_token_index = 32000 + image_token_id = 32000 elif model_id == "liuhaotian/llava-v1.6-vicuna-7b": text_model_id = "lmsys/vicuna-7b-v1.5" - image_token_index = 32000 + image_token_id = 32000 elif model_id == "liuhaotian/llava-v1.6-vicuna-13b": text_model_id = "lmsys/vicuna-13b-v1.5" - image_token_index = 32000 + image_token_id = 32000 elif model_id == "liuhaotian/llava-v1.6-34b": text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B" - image_token_index = 64000 + image_token_id = 64000 elif model_id == "lmms-lab/llama3-llava-next-8b": text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" - image_token_index = 128256 + image_token_id = 128256 elif model_id == "lmms-lab/llava-next-72b": text_model_id = "Qwen/Qwen1.5-72B-Chat" - image_token_index = 151646 + image_token_id = 151646 elif model_id == "lmms-lab/llava-next-110b": text_model_id = "Qwen/Qwen1.5-110B-Chat" - image_token_index = 151646 + image_token_id = 151646 vision_model_id = data["mm_vision_tower"] @@ -142,7 +142,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): text_config=text_config.to_dict(), image_grid_pinpoints=image_processor.image_grid_pinpoints, use_image_newline_parameter=True, - image_token_index=image_token_index, + image_token_id=image_token_id, ) with init_empty_weights(): @@ -225,8 +225,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, 
push_to_hub=False): if model_id == "liuhaotian/llava-v1.6-mistral-7b": filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset") original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True) - # replace -200 by image_token_index (since we use token ID = 32000 for the image token) - original_input_ids[original_input_ids == -200] = image_token_index + # replace -200 by image_token_id (since we use token ID = 32000 for the image token) + original_input_ids[original_input_ids == -200] = image_token_id assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist() elif model_id == "liuhaotian/llava-v1.6-34b": @@ -234,8 +234,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): repo_id="nielsr/test-image", filename="llava_1_6_34b_input_ids.pt", repo_type="dataset" ) original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True) - # replace -200 by image_token_index - original_input_ids[original_input_ids == -200] = image_token_index + # replace -200 by image_token_id + original_input_ids[original_input_ids == -200] = image_token_id assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist() diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index b3eae9c4431..90cf7ea252b 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -624,10 +624,10 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi image_newline=self.image_newline, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py index 01450f6b587..8a5f7fecee3 100644 --- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -88,6 +88,10 @@ class LlavaNextVideoConfig(PretrainedConfig): ```""" model_type = "llava_next_video" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} def __init__( diff --git a/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py b/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py index aae44eee97a..2877b2e9dd8 100644 --- a/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py +++ b/src/transformers/models/llava_next_video/convert_llava_next_video_weights_to_hf.py @@ -159,18 +159,18 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): if model_id == "lmms-lab/LLaVA-NeXT-Video-7B-32K": 
text_model_id = "mistralai/Mistral-7B-Instruct-v0.2" - video_token_index = 32000 - image_token_index = 32001 + video_token_id = 32000 + image_token_id = 32001 overwrite_text_config = {} elif model_id in ["lmms-lab/LLaVA-NeXT-Video-7B", "lmms-lab/LLaVA-NeXT-Video-7B-DPO"]: text_model_id = "lmsys/vicuna-7b-v1.5" - video_token_index = 32000 - image_token_index = 32001 + video_token_id = 32000 + image_token_id = 32001 overwrite_text_config = {"factor": 2.0, "type": "linear"} elif model_id in ["lmms-lab/LLaVA-NeXT-Video-34B", "lmms-lab/LLaVA-NeXT-Video-34B-DPO"]: text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B" - video_token_index = 64000 - image_token_index = 64001 + video_token_id = 64000 + image_token_id = 64001 overwrite_text_config = {} else: raise ValueError("Incorrect checkpoint referenced. Text model-id not identified!") @@ -199,8 +199,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): text_config=text_config, image_grid_pinpoints=image_processor.image_grid_pinpoints, use_image_newline_parameter=True, - video_token_index=video_token_index, - image_token_index=image_token_index, + video_token_id=video_token_id, + image_token_id=image_token_id, ) with init_empty_weights(): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 1965f08bb46..c89e2d72ef8 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -703,10 +703,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene image_newline=self.image_newline, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -725,10 +725,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 0109082d841..b0b744c5b32 100644 --- 
a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -104,6 +104,10 @@ class LlavaNextVideoConfig(PretrainedConfig): ```""" model_type = "llava_next_video" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} def __init__( @@ -489,10 +493,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): image_newline=self.image_newline, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -511,10 +515,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/llava_onevision/configuration_llava_onevision.py b/src/transformers/models/llava_onevision/configuration_llava_onevision.py index c3d43d69d76..25f9c30d933 100644 --- a/src/transformers/models/llava_onevision/configuration_llava_onevision.py +++ b/src/transformers/models/llava_onevision/configuration_llava_onevision.py @@ -85,6 +85,10 @@ class LlavaOnevisionConfig(PretrainedConfig): ```""" model_type = "llava_onevision" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} def __init__( diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index ec928380914..dfd43643958 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -700,10 +700,10 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene vision_aspect_ratio=vision_aspect_ratio, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) 
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -724,10 +724,10 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) - special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel(): - n_video_tokens = (input_ids == self.config.video_token_index).sum() + n_video_tokens = (input_ids == self.config.video_token_id).sum() n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/mistral3/configuration_mistral3.py b/src/transformers/models/mistral3/configuration_mistral3.py index e7b27d57220..357ba2cddd4 100644 --- a/src/transformers/models/mistral3/configuration_mistral3.py +++ b/src/transformers/models/mistral3/configuration_mistral3.py @@ -68,6 +68,9 @@ class Mistral3Config(PretrainedConfig): ```""" model_type = "mistral3" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} is_composition = True diff --git a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py index 11b2d18f04d..c8f9b64ab1f 100644 --- a/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py +++ b/src/transformers/models/mistral3/convert_mistral3_weights_to_hf.py @@ -161,7 +161,7 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072) vision_config=new_vision_config, text_config=new_text_config, multimodal_projector_bias=adapter_bias, - image_token_index=image_token_id, + image_token_id=image_token_id, spatial_merge_size=spatial_merge_size, vision_feature_layer=-1, ) diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index c27a646dd68..5ce7763dd7c 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -449,10 +449,10 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) image_sizes=image_sizes, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: 
{n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 5eebcd8d560..36fd4526838 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -233,10 +233,10 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration): image_sizes=image_sizes, ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/mllama/configuration_mllama.py b/src/transformers/models/mllama/configuration_mllama.py index fc655c9944e..b501e9efd35 100644 --- a/src/transformers/models/mllama/configuration_mllama.py +++ b/src/transformers/models/mllama/configuration_mllama.py @@ -337,6 +337,9 @@ class MllamaConfig(PretrainedConfig): ```""" model_type = "mllama" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": MllamaTextConfig, "vision_config": MllamaVisionConfig} def __init__( diff --git a/src/transformers/models/paligemma/configuration_paligemma.py b/src/transformers/models/paligemma/configuration_paligemma.py index 918095ce85e..4551b85bcd5 100644 --- a/src/transformers/models/paligemma/configuration_paligemma.py +++ b/src/transformers/models/paligemma/configuration_paligemma.py @@ -73,6 +73,9 @@ class PaliGemmaConfig(PretrainedConfig): ```""" model_type = "paligemma" + attribute_map = { + "image_token_id": "image_token_index", + } sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} keys_to_ignore_at_inference = ["past_key_values"] diff --git a/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py b/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py index df869fcefb2..3334e6f28fc 100644 --- a/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py +++ b/src/transformers/models/paligemma/convert_paligemma2_weights_to_hf.py @@ -80,7 +80,7 @@ DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch def get_paligemma2_config(variant: str, precision: str): config = { - "image_token_index": None, + "image_token_id": None, "pad_token_id": 0, "bos_token_id": 2, "eos_token_id": 1, @@ -93,7 +93,7 @@ def get_paligemma2_config(variant: str, precision: str): patch_size = 14 num_image_tokens = (image_size**2) // (patch_size**2) config["projection_dim"] = variant_config["hidden_size"] - config["image_token_index"] = 257152 + config["image_token_id"] = 257152 config["num_hidden_layers"] = variant_config["num_hidden_layers"] # For generate text_config = Gemma2Config.from_pretrained("google/gemma-2-2b-it").to_dict() sup_text_config = { diff --git a/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py b/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py index bcea5372e57..054872a799a 100644 --- 
a/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py +++ b/src/transformers/models/paligemma/convert_paligemma_weights_to_hf.py @@ -45,7 +45,7 @@ PALIGEMMA_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"] def get_paligemma_config(variant: str, precision: str): config = { - "image_token_index": None, + "image_token_id": None, "pad_token_id": 0, "bos_token_id": 2, "eos_token_id": 1, @@ -58,7 +58,7 @@ def get_paligemma_config(variant: str, precision: str): patch_size = 14 num_image_tokens = (image_size**2) // (patch_size**2) - config["image_token_index"] = 257152 if variant != "2b-test" else 256000 + config["image_token_id"] = 257152 if variant != "2b-test" else 256000 text_config = { "vocab_size": 257152, "num_hidden_layers": 18, diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index b7ac3b751b5..ade69444f46 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -473,8 +473,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi is_training = token_type_ids is not None and labels is not None # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() llm_input_ids[special_image_mask] = 0 else: @@ -498,10 +498,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py index da5af9fb344..373aa6cb6e4 100644 --- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py +++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py @@ -187,7 +187,7 @@ def convert_mistral_model(input_dir, output_dir): vision_config, text_config, vision_feature_layer=-1, - image_token_index=10, + image_token_id=10, vision_feature_select_strategy="full", image_seq_length=1, multimodal_projector_bias=adapter_bias, diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py index d079846708c..22873e16d07 100644 --- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -458,6 +458,11 @@ class Qwen2_5OmniThinkerConfig(PretrainedConfig): ```""" model_type = "qwen2_5_omni_thinker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + 
"audio_token_id": "audio_token_index", + } sub_configs = { "audio_config": Qwen2_5OmniAudioEncoderConfig, "vision_config": Qwen2_5OmniVisionEncoderConfig, @@ -662,6 +667,11 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig): ```""" model_type = "qwen2_5_omni_talker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } def __init__( self, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 2898354399b..8accfc638e3 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -334,9 +334,9 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) """ spatial_merge_size = self.spatial_merge_size - image_token_id = self.config.image_token_index - video_token_id = self.config.video_token_index - audio_token_id = self.config.audio_token_index + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id vision_start_token_id = self.config.vision_start_token_id audio_start_token_id = self.config.audio_start_token_id position_id_per_seconds = self.config.position_id_per_seconds @@ -2450,7 +2450,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo if audio_features.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") audio_mask = ( - (input_ids == self.config.audio_token_index) + (input_ids == self.config.audio_token_id) .unsqueeze(-1) .expand_as(inputs_embeds) .to(inputs_embeds.device) @@ -2462,7 +2462,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) image_mask = ( - (input_ids == self.config.image_token_index) + (input_ids == self.config.image_token_id) .unsqueeze(-1) .expand_as(inputs_embeds) .to(inputs_embeds.device) @@ -2474,7 +2474,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) video_mask = ( - (input_ids == self.config.video_token_index) + (input_ids == self.config.video_token_id) .unsqueeze(-1) .expand_as(inputs_embeds) .to(inputs_embeds.device) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 9bb81ddcc59..2524fd9186a 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -443,6 +443,11 @@ class Qwen2_5OmniThinkerConfig(PretrainedConfig): ```""" model_type = "qwen2_5_omni_thinker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } sub_configs = { "audio_config": Qwen2_5OmniAudioEncoderConfig, "vision_config": Qwen2_5OmniVisionEncoderConfig, @@ -647,6 +652,11 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig): ```""" model_type = "qwen2_5_omni_talker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + 
"audio_token_id": "audio_token_index", + } def __init__( self, @@ -1225,9 +1235,9 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) """ spatial_merge_size = self.spatial_merge_size - image_token_id = self.config.image_token_index - video_token_id = self.config.video_token_index - audio_token_id = self.config.audio_token_index + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id vision_start_token_id = self.config.vision_start_token_id audio_start_token_id = self.config.audio_start_token_id position_id_per_seconds = self.config.position_id_per_seconds @@ -2400,7 +2410,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo if audio_features.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") audio_mask = ( - (input_ids == self.config.audio_token_index) + (input_ids == self.config.audio_token_id) .unsqueeze(-1) .expand_as(inputs_embeds) .to(inputs_embeds.device) @@ -2412,7 +2422,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) image_mask = ( - (input_ids == self.config.image_token_index) + (input_ids == self.config.image_token_id) .unsqueeze(-1) .expand_as(inputs_embeds) .to(inputs_embeds.device) @@ -2424,7 +2434,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) video_mask = ( - (input_ids == self.config.video_token_index) + (input_ids == self.config.video_token_id) .unsqueeze(-1) .expand_as(inputs_embeds) .to(inputs_embeds.device) diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index 588ad214ebd..a700a992886 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -156,6 +156,10 @@ class Qwen2_5_VLTextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + image_token_id (`int`, *optional*): + Token index used as placeholder for image embeddings. + video_token_id (`int`, *optional*): + Token index used as placeholder for video embeddings. 
    ```python
    >>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig
@@ -209,6 +213,8 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
        max_window_layers=80,
        attention_dropout=0.0,
        rope_scaling=None,
+        image_token_id=None,
+        video_token_id=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -244,6 +250,8 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
            self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self, ignore_keys={"mrope_section"})
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
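Unlike the composite configs, the Qwen2-VL-style text configs expose the new names as plain constructor arguments instead of `attribute_map` aliases. A short usage sketch; the module path and class name are taken from the diff, while the concrete token ids are placeholder values chosen for illustration:

```python
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLTextConfig

# Both ids default to None and are simply stored on the config.
config = Qwen2_5_VLTextConfig(image_token_id=151655, video_token_id=151656)
assert config.image_token_id == 151655
assert config.video_token_id == 151656

assert Qwen2_5_VLTextConfig().image_token_id is None
```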
diff --git a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py
index 53bd4df9c31..23a1c699dd4 100644
--- a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py
@@ -157,6 +157,9 @@ class Qwen2AudioConfig(PretrainedConfig):
    ```"""

    model_type = "qwen2_audio"
+    attribute_map = {
+        "audio_token_id": "audio_token_index",
+    }
    sub_configs = {"text_config": AutoConfig, "audio_config": AutoConfig}

    def __init__(
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index ad586a45c50..4496ef73b8c 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -932,7 +932,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
            raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")

        # 1. Create a mask to know where special audio tokens are
-        special_audio_token_mask = input_ids == self.config.audio_token_index
+        special_audio_token_mask = input_ids == self.config.audio_token_id
        num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)

        # In case the Audio model or the Language model has been offloaded to CPU, we need to manually
@@ -942,7 +942,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
        input_ids = input_ids.to(target_device)
        num_audio_tokens = num_audio_tokens.to(target_device)
        batch_indices, non_audio_indices = torch.where(
-            (input_ids != self.config.audio_token_index) & (attention_mask == 1)
+            (input_ids != self.config.audio_token_id) & (attention_mask == 1)
        )

        # 2. Compute the positions where text should be written
@@ -1114,7 +1114,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
            audio_features = self.multi_modal_projector(selected_audio_feature)

            # if we have consecutive audio tokens, then it means we expanded input_ids in processing
-            audio_tokens = input_ids == self.config.audio_token_index
+            audio_tokens = input_ids == self.config.audio_token_id
            legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0

            if legacy_processing:
@@ -1130,14 +1130,14 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
                audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
                audio_features = audio_features[audio_features_mask]

-                n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
+                n_audio_tokens = (input_ids == self.config.audio_token_id).sum().item()
                n_audio_features = audio_features.shape[0]

                if n_audio_tokens != n_audio_features:
                    raise ValueError(
                        f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
                    )
-                special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
+                special_audio_mask = (input_ids == self.config.audio_token_id).to(inputs_embeds.device)
                special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
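Qwen2-Audio also keeps a heuristic that decides whether the processor already expanded the audio placeholder: if no two audio tokens are adjacent, the inputs are treated as legacy, unexpanded ones. A toy version of that check, with an invented token id:

```python
import torch

AUDIO_TOKEN_ID = 42  # invented placeholder id for the sketch


def is_legacy_processing(input_ids: torch.Tensor) -> bool:
    audio_tokens = input_ids == AUDIO_TOKEN_ID
    # Legacy inputs hold a single placeholder per audio clip, so no two
    # audio tokens ever sit next to each other.
    return bool((audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0)


assert is_legacy_processing(torch.tensor([[1, 42, 2, 3, 42, 4]]))       # one token per clip
assert not is_legacy_processing(torch.tensor([[1, 42, 42, 42, 2, 3]]))  # already expanded
```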
diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
index ee2ed40e463..5f4842ce604 100644
--- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -145,6 +145,10 @@ class Qwen2VLTextConfig(PretrainedConfig):
            Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
        `high_freq_factor` (`float`, *optional*):
            Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        image_token_id (`int`, *optional*):
+            Token index used as placeholder for image embeddings.
+        video_token_id (`int`, *optional*):
+            Token index used as placeholder for video embeddings.

    ```python
    >>> from transformers import Qwen2VLTextModel, Qwen2VLConfig
@@ -198,6 +202,8 @@ class Qwen2VLTextConfig(PretrainedConfig):
        max_window_layers=80,
        attention_dropout=0.0,
        rope_scaling=None,
+        image_token_id=None,
+        video_token_id=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -233,6 +239,8 @@ class Qwen2VLTextConfig(PretrainedConfig):
            self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self, ignore_keys={"mrope_section"})
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
diff --git a/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py b/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py
index 8094cb14b43..acb53924b2e 100644
--- a/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py
+++ b/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py
@@ -72,6 +72,11 @@ class ShieldGemma2Config(PretrainedConfig):
    ```"""

    model_type = "shieldgemma2"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "boi_token_id": "boi_token_index",
+        "eoi_token_id": "eoi_token_index",
+    }
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

    def __init__(
diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py
index e761481d825..402d9946731 100644
--- a/src/transformers/models/video_llava/configuration_video_llava.py
+++ b/src/transformers/models/video_llava/configuration_video_llava.py
@@ -80,6 +80,10 @@ class VideoLlavaConfig(PretrainedConfig):
    ```"""

    model_type = "video_llava"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "video_token_id": "video_token_index",
+    }
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

    def __init__(
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index f092bba196d..70ebef344c3 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -493,10 +493,10 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                n_image_features = image_features.shape[0] * image_features.shape[1]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -509,10 +509,10 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
                pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
            )
-            special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_index).sum()
+                n_video_tokens = (input_ids == self.config.video_token_id).sum()
                n_video_features = video_features.shape[0] * video_features.shape[1]
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
diff --git a/src/transformers/models/vipllava/configuration_vipllava.py b/src/transformers/models/vipllava/configuration_vipllava.py
index ac24cce2412..f748201f7d7 100644
--- a/src/transformers/models/vipllava/configuration_vipllava.py
+++ b/src/transformers/models/vipllava/configuration_vipllava.py
@@ -70,6 +70,9 @@ class VipLlavaConfig(PretrainedConfig):
    ```"""

    model_type = "vipllava"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+    }
    sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

    def __init__(
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index c6060b756eb..9243bbe9e2d 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -374,10 +374,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
                pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
            )

-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                n_image_features = image_features.shape[0] * image_features.shape[1]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index fa8bd274cce..5ae15113108 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -224,12 +224,9 @@ class GenerationTesterMixin:
        # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
        if config is not None:
            for key in [
-                "image_token_index",
                "image_token_id",
-                "video_token_index",
                "video_token_id",
                "vision_start_token_id",
-                "audio_token_index",
                "audio_start_token_id",
                "audio_end_token_id",
                "vision_end_token_id",
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index a0b31c8cc02..ca677f6549b 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -351,8 +351,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
        "pad_index",
        "unk_index",
        "mask_index",
-        "image_token_index",  # for VLMs
-        "video_token_index",
+        "image_token_id",  # for VLMs
+        "video_token_id",
        "image_seq_length",
        "video_seq_length",
        "image_size",