Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
[VLMs] use only `xxx_token_id` for multimodal tokens (#37573)

* use only `xxx_token_id` for multimodal tokens
* update modeling files as well
* fixup
* why fixup doesn't fix modular docstring first?
* janus, need to update configs in the hub still
* last fixup
This commit is contained in:

parent 4afd3f4820
commit 2ba6b92a6f
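In short, every VLM config touched by this commit now exposes its multimodal placeholder token as `xxx_token_id`, while an `attribute_map` entry aliases the new name to the existing `xxx_token_index` attribute, so both spellings keep working and Hub configs serialized with the old key still load. A minimal sketch of what the rename means for downstream code (not part of the diff; assumes a transformers build that includes this change, using `LlavaConfig` as one of the configs modified below):

```python
from transformers import LlavaConfig

config = LlavaConfig()

# New canonical name, used throughout the updated modeling code.
image_id = config.image_token_id

# Legacy name still resolves to the same value because of
# `attribute_map = {"image_token_id": "image_token_index"}`,
# so checkpoints on the Hub serialized with `image_token_index` keep loading.
assert config.image_token_index == image_id
```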
@@ -258,6 +258,9 @@ class AriaConfig(PretrainedConfig):
     """
 
     model_type = "aria"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+    }
     sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
 
     def __init__(
@@ -106,7 +106,7 @@ def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, ol
     config.vision_config.hidden_size = 1152
     config.vision_config.attention_heads = 16
     config.pad_token_id = 2
-    config.image_token_index = 9
+    config.image_token_id = 9
     config.intermediate_size = config.moe_intermediate_size
     config.auto_map = {
         "AutoConfig": "modeling_aria.AriaConfig",
@@ -1507,11 +1507,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
         if pixel_values is not None and inputs_embeds.shape[1] != 1:
             if input_ids is None:
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
                 n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
             else:
-                image_embeds = input_ids == self.config.image_token_index
+                image_embeds = input_ids == self.config.image_token_id
                 special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
                 n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
             image_features = self.get_image_features(
@@ -266,6 +266,9 @@ class AriaConfig(PretrainedConfig):
     """
 
     model_type = "aria"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+    }
     sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
 
     def __init__(
@@ -1546,11 +1549,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
         if pixel_values is not None and inputs_embeds.shape[1] != 1:
             if input_ids is None:
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
                 n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
             else:
-                image_embeds = input_ids == self.config.image_token_index
+                image_embeds = input_ids == self.config.image_token_id
                 special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
                 n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
             image_features = self.get_image_features(
@@ -52,6 +52,9 @@ class AyaVisionConfig(PretrainedConfig):
     """
 
     model_type = "aya_vision"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+    }
     sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
 
     def __init__(
@@ -444,10 +444,10 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
                 image_sizes=image_sizes,
             )
 
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
             special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                 n_image_tokens = (input_ids == self.config.image_token_index).sum()
-                n_image_tokens = (input_ids == self.config.image_token_index).sum()
+                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@ -273,6 +273,9 @@ class Blip2Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "blip-2"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "qformer_config": Blip2QFormerConfig, "vision_config": Blip2VisionConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -2283,10 +2283,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
@ -2406,8 +2406,8 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
@ -2415,10 +2415,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
|
@ -285,6 +285,11 @@ class Gemma3Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "gemma3"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"boi_token_id": "boi_token_index",
|
||||
"eoi_token_id": "eoi_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"text_config": Gemma3TextConfig,
|
||||
"vision_config": SiglipVisionConfig,
|
||||
|
@ -1274,8 +1274,8 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
@ -1296,10 +1296,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
|
@ -257,6 +257,11 @@ class Gemma3Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "gemma3"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"boi_token_id": "boi_token_index",
|
||||
"eoi_token_id": "eoi_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"text_config": Gemma3TextConfig,
|
||||
"vision_config": SiglipVisionConfig,
|
||||
@ -922,8 +927,8 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
@ -944,10 +949,10 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
|
@ -153,6 +153,9 @@ class GotOcr2Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "got_ocr2"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -813,13 +813,13 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
@ -176,6 +176,9 @@ class GotOcr2Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "got_ocr2"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": GotOcr2VisionConfig}
|
||||
|
||||
def __init__(
|
||||
@ -477,13 +480,13 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
@ -147,6 +147,9 @@ class GraniteSpeechConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "granite_speech"
|
||||
attribute_map = {
|
||||
"audio_token_id": "audio_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"text_config": AutoConfig,
|
||||
"encoder_config": GraniteSpeechEncoderConfig,
|
||||
|
@ -515,7 +515,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
|
||||
# Get the base embeddings; set all audio tokens to 0 index
|
||||
# to avoid out of vocabulary issues with the LLM embedding.
|
||||
# Audio features will be masked into is_audio_idx indices later.
|
||||
is_audio_idx = input_ids == self.config.audio_token_index
|
||||
is_audio_idx = input_ids == self.config.audio_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[is_audio_idx] = 0
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
@ -624,7 +624,7 @@ class GraniteSpeechForConditionalGeneration(GraniteSpeechPreTrainedModel, Genera
|
||||
input_features_mask (`torch.Tensor`, *optional*, defaults to `None`)
|
||||
Mask to be applied to audio features prior to scattering into the language embeddings.
|
||||
"""
|
||||
is_audio_index = input_ids == self.config.audio_token_index
|
||||
is_audio_index = input_ids == self.config.audio_token_id
|
||||
llm_input_ids = torch.where(is_audio_index, 0, input_ids)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
|
||||
|
||||
|
@ -268,6 +268,9 @@ class InstructBlipConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "instructblip"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"text_config": AutoConfig,
|
||||
"qformer_config": InstructBlipQFormerConfig,
|
||||
|
@ -1467,10 +1467,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
else:
|
||||
logger.warning_once(
|
||||
@ -1599,8 +1599,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
@ -1609,10 +1609,10 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "image_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
else:
|
||||
logger.warning_once(
|
||||
|
@ -274,6 +274,9 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "instructblipvideo"
|
||||
attribute_map = {
|
||||
"video_token_id": "video_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"text_config": AutoConfig,
|
||||
"qformer_config": InstructBlipVideoQFormerConfig,
|
||||
|
@ -1495,10 +1495,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
else:
|
||||
logger.warning_once(
|
||||
@ -1635,8 +1635,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
@ -1645,10 +1645,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
else:
|
||||
logger.warning_once(
|
||||
|
@ -106,6 +106,9 @@ class InstructBlipVideoConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "instructblipvideo"
|
||||
attribute_map = {
|
||||
"video_token_id": "video_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"text_config": AutoConfig,
|
||||
"qformer_config": InstructBlipVideoQFormerConfig,
|
||||
@ -315,10 +318,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
else:
|
||||
logger.warning_once(
|
||||
@ -455,8 +458,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
|
||||
@ -465,10 +468,10 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "video_token_index" then the input is expanded to account for image embeds
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_index", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
else:
|
||||
logger.warning_once(
|
||||
|
@@ -230,6 +230,8 @@ class JanusConfig(PretrainedConfig):
             The config object or dictionary of the vision backbone.
         vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`):
             The config object or dictionary of the VQVAE backbone.
+        image_token_id (`int`, *optional*, defaults to 100581):
+            Token index of a placeholder image token.
 
     Example:
 
@@ -262,7 +264,14 @@ class JanusConfig(PretrainedConfig):
         "vq_config": JanusVQVAEConfig,
     }
 
-    def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs):
+    def __init__(
+        self,
+        text_config=None,
+        vision_config=None,
+        vq_config=None,
+        image_token_id=100581,
+        **kwargs,
+    ):
         if isinstance(text_config, dict):
             text_config["model_type"] = text_config.get("model_type", "llama")
             self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
@@ -307,7 +316,7 @@ class JanusConfig(PretrainedConfig):
         # This dimension is required when decoding discrete image tokens to continuous input.
         self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
         # The default is only the index for the 1B model, 7B uses a different one
-        self.image_token_index = kwargs.get("image_token_index", 100581)
+        self.image_token_id = image_token_id
         super().__init__(**kwargs)
 
 
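The Janus change is slightly larger than a plain rename: `image_token_id` becomes an explicit `__init__` argument instead of being read from `kwargs` as `image_token_index`. A minimal sketch of the new constructor usage (not part of the diff; assumes a transformers build that includes this change, and note that the 100581 default only matches the 1B checkpoint, as the inline comment above says):

```python
from transformers import JanusConfig

# `image_token_id` is now a named argument with a default of 100581 (the 1B vocab id).
# For other checkpoints, pass the id of "<image_placeholder>" from that tokenizer's vocab,
# as the conversion script below does.
config = JanusConfig(image_token_id=100581)
print(config.image_token_id)  # 100581
```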
@ -390,7 +390,7 @@ def convert_model(
|
||||
text_config=text_config,
|
||||
vision_config=vision_config,
|
||||
vq_config=vq_config,
|
||||
image_token_index=tokenizer.vocab.get("<image_placeholder>"),
|
||||
image_token_id=tokenizer.vocab.get("<image_placeholder>"),
|
||||
)
|
||||
|
||||
# Save the config
|
||||
|
@ -1281,7 +1281,7 @@ class JanusModel(JanusPreTrainedModel):
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
image_attention_mask = input_ids == self.config.image_token_index
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
embed_dim = inputs_embeds.shape[-1]
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
|
@ -306,6 +306,8 @@ class JanusConfig(PretrainedConfig):
|
||||
The config object or dictionary of the vision backbone.
|
||||
vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`):
|
||||
The config object or dictionary of the VQVAE backbone.
|
||||
image_token_id (`int`, *optional*, defaults to 100581):
|
||||
Token index of a placeholder image token.
|
||||
|
||||
Example:
|
||||
|
||||
@ -338,7 +340,14 @@ class JanusConfig(PretrainedConfig):
|
||||
"vq_config": JanusVQVAEConfig,
|
||||
}
|
||||
|
||||
def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
vq_config=None,
|
||||
image_token_id=100581,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(text_config, dict):
|
||||
text_config["model_type"] = text_config.get("model_type", "llama")
|
||||
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
||||
@ -383,7 +392,7 @@ class JanusConfig(PretrainedConfig):
|
||||
# This dimension is required when decoding discrete image tokens to continuous input.
|
||||
self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
|
||||
# The default is only the index for the 1B model, 7B uses a different one
|
||||
self.image_token_index = kwargs.get("image_token_index", 100581)
|
||||
self.image_token_id = image_token_id
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@ -1081,7 +1090,7 @@ class JanusModel(JanusPreTrainedModel):
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
image_attention_mask = input_ids == self.config.image_token_index
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
embed_dim = inputs_embeds.shape[-1]
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
|
@ -395,6 +395,11 @@ class Llama4Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llama4"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"boi_token_id": "boi_token_index",
|
||||
"eoi_token_id": "eoi_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
|
||||
base_model_tp_plan = {
|
||||
"multi_modal_projector.linear_1": "colwise_rep",
|
||||
|
@ -1745,7 +1745,7 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
|
||||
|
||||
|
@ -75,6 +75,9 @@ class LlavaConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llava"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -129,13 +129,13 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o
|
||||
|
||||
# llms-lab interleeave models do not use any selection startegy except for last hidden state
|
||||
if "Qwen" in text_model_id:
|
||||
config.image_token_index = 151646
|
||||
config.image_token_id = 151646
|
||||
if "siglip" in vision_model_id:
|
||||
config.vision_feature_select_strategy = "full"
|
||||
config.vision_feature_layer = -1
|
||||
else:
|
||||
config.pad_token_id = 32001
|
||||
config.image_token_index = 32000
|
||||
config.image_token_id = 32000
|
||||
|
||||
with torch.device("meta"):
|
||||
model = LlavaForConditionalGeneration(config)
|
||||
|
@ -405,10 +405,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -80,6 +80,9 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llava_next"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -102,25 +102,25 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
|
||||
if model_id == "liuhaotian/llava-v1.6-mistral-7b":
|
||||
text_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
image_token_index = 32000
|
||||
image_token_id = 32000
|
||||
elif model_id == "liuhaotian/llava-v1.6-vicuna-7b":
|
||||
text_model_id = "lmsys/vicuna-7b-v1.5"
|
||||
image_token_index = 32000
|
||||
image_token_id = 32000
|
||||
elif model_id == "liuhaotian/llava-v1.6-vicuna-13b":
|
||||
text_model_id = "lmsys/vicuna-13b-v1.5"
|
||||
image_token_index = 32000
|
||||
image_token_id = 32000
|
||||
elif model_id == "liuhaotian/llava-v1.6-34b":
|
||||
text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
|
||||
image_token_index = 64000
|
||||
image_token_id = 64000
|
||||
elif model_id == "lmms-lab/llama3-llava-next-8b":
|
||||
text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
image_token_index = 128256
|
||||
image_token_id = 128256
|
||||
elif model_id == "lmms-lab/llava-next-72b":
|
||||
text_model_id = "Qwen/Qwen1.5-72B-Chat"
|
||||
image_token_index = 151646
|
||||
image_token_id = 151646
|
||||
elif model_id == "lmms-lab/llava-next-110b":
|
||||
text_model_id = "Qwen/Qwen1.5-110B-Chat"
|
||||
image_token_index = 151646
|
||||
image_token_id = 151646
|
||||
|
||||
vision_model_id = data["mm_vision_tower"]
|
||||
|
||||
@ -142,7 +142,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
text_config=text_config.to_dict(),
|
||||
image_grid_pinpoints=image_processor.image_grid_pinpoints,
|
||||
use_image_newline_parameter=True,
|
||||
image_token_index=image_token_index,
|
||||
image_token_id=image_token_id,
|
||||
)
|
||||
|
||||
with init_empty_weights():
|
||||
@ -225,8 +225,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
if model_id == "liuhaotian/llava-v1.6-mistral-7b":
|
||||
filepath = hf_hub_download(repo_id="nielsr/test-image", filename="llava_1_6_input_ids.pt", repo_type="dataset")
|
||||
original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
|
||||
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
|
||||
original_input_ids[original_input_ids == -200] = image_token_index
|
||||
# replace -200 by image_token_id (since we use token ID = 32000 for the image token)
|
||||
original_input_ids[original_input_ids == -200] = image_token_id
|
||||
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
|
||||
|
||||
elif model_id == "liuhaotian/llava-v1.6-34b":
|
||||
@ -234,8 +234,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
repo_id="nielsr/test-image", filename="llava_1_6_34b_input_ids.pt", repo_type="dataset"
|
||||
)
|
||||
original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
|
||||
# replace -200 by image_token_index
|
||||
original_input_ids[original_input_ids == -200] = image_token_index
|
||||
# replace -200 by image_token_id
|
||||
original_input_ids[original_input_ids == -200] = image_token_id
|
||||
|
||||
assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
|
||||
|
||||
|
@ -624,10 +624,10 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -88,6 +88,10 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llava_next_video"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -159,18 +159,18 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
|
||||
if model_id == "lmms-lab/LLaVA-NeXT-Video-7B-32K":
|
||||
text_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
video_token_index = 32000
|
||||
image_token_index = 32001
|
||||
video_token_id = 32000
|
||||
image_token_id = 32001
|
||||
overwrite_text_config = {}
|
||||
elif model_id in ["lmms-lab/LLaVA-NeXT-Video-7B", "lmms-lab/LLaVA-NeXT-Video-7B-DPO"]:
|
||||
text_model_id = "lmsys/vicuna-7b-v1.5"
|
||||
video_token_index = 32000
|
||||
image_token_index = 32001
|
||||
video_token_id = 32000
|
||||
image_token_id = 32001
|
||||
overwrite_text_config = {"factor": 2.0, "type": "linear"}
|
||||
elif model_id in ["lmms-lab/LLaVA-NeXT-Video-34B", "lmms-lab/LLaVA-NeXT-Video-34B-DPO"]:
|
||||
text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
|
||||
video_token_index = 64000
|
||||
image_token_index = 64001
|
||||
video_token_id = 64000
|
||||
image_token_id = 64001
|
||||
overwrite_text_config = {}
|
||||
else:
|
||||
raise ValueError("Incorrect checkpoint referenced. Text model-id not identified!")
|
||||
@ -199,8 +199,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
|
||||
text_config=text_config,
|
||||
image_grid_pinpoints=image_processor.image_grid_pinpoints,
|
||||
use_image_newline_parameter=True,
|
||||
video_token_index=video_token_index,
|
||||
image_token_index=image_token_index,
|
||||
video_token_id=video_token_id,
|
||||
image_token_id=image_token_id,
|
||||
)
|
||||
|
||||
with init_empty_weights():
|
||||
|
@ -703,10 +703,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -725,10 +725,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -104,6 +104,10 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llava_next_video"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
@ -489,10 +493,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -511,10 +515,10 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -85,6 +85,10 @@ class LlavaOnevisionConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "llava_onevision"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -700,10 +700,10 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -724,10 +724,10 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_index).sum()
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -68,6 +68,9 @@ class Mistral3Config(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "mistral3"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
is_composition = True
|
||||
|
||||
|
@ -161,7 +161,7 @@ def convert_config(original_config: dict, max_position_embeddings: int = 131072)
|
||||
vision_config=new_vision_config,
|
||||
text_config=new_text_config,
|
||||
multimodal_projector_bias=adapter_bias,
|
||||
image_token_index=image_token_id,
|
||||
image_token_id=image_token_id,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
vision_feature_layer=-1,
|
||||
)
|
||||
|
@ -449,10 +449,10 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -233,10 +233,10 @@ class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -337,6 +337,9 @@ class MllamaConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "mllama"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": MllamaTextConfig, "vision_config": MllamaVisionConfig}
|
||||
|
||||
def __init__(
|
||||
|
@ -73,6 +73,9 @@ class PaliGemmaConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "paligemma"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
}
|
||||
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
|
@ -80,7 +80,7 @@ DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch
|
||||
|
||||
def get_paligemma2_config(variant: str, precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"image_token_id": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
@ -93,7 +93,7 @@ def get_paligemma2_config(variant: str, precision: str):
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
config["projection_dim"] = variant_config["hidden_size"]
|
||||
config["image_token_index"] = 257152
|
||||
config["image_token_id"] = 257152
|
||||
config["num_hidden_layers"] = variant_config["num_hidden_layers"] # For generate
|
||||
text_config = Gemma2Config.from_pretrained("google/gemma-2-2b-it").to_dict()
|
||||
sup_text_config = {
|
||||
|
@ -45,7 +45,7 @@ PALIGEMMA_VARIANTS = ["2b-test", "3b-224px", "3b-448px", "3b-896px"]
|
||||
|
||||
def get_paligemma_config(variant: str, precision: str):
|
||||
config = {
|
||||
"image_token_index": None,
|
||||
"image_token_id": None,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 1,
|
||||
@ -58,7 +58,7 @@ def get_paligemma_config(variant: str, precision: str):
|
||||
patch_size = 14
|
||||
num_image_tokens = (image_size**2) // (patch_size**2)
|
||||
|
||||
config["image_token_index"] = 257152 if variant != "2b-test" else 256000
|
||||
config["image_token_id"] = 257152 if variant != "2b-test" else 256000
|
||||
text_config = {
|
||||
"vocab_size": 257152,
|
||||
"num_hidden_layers": 18,
|
||||
|
@ -473,8 +473,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
@ -498,10 +498,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
|
@ -187,7 +187,7 @@ def convert_mistral_model(input_dir, output_dir):
|
||||
vision_config,
|
||||
text_config,
|
||||
vision_feature_layer=-1,
|
||||
image_token_index=10,
|
||||
image_token_id=10,
|
||||
vision_feature_select_strategy="full",
|
||||
image_seq_length=1,
|
||||
multimodal_projector_bias=adapter_bias,
|
||||
|
@ -458,6 +458,11 @@ class Qwen2_5OmniThinkerConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_omni_thinker"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
"audio_token_id": "audio_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"audio_config": Qwen2_5OmniAudioEncoderConfig,
|
||||
"vision_config": Qwen2_5OmniVisionEncoderConfig,
|
||||
@ -662,6 +667,11 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_omni_talker"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
"audio_token_id": "audio_token_index",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -334,9 +334,9 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
|
||||
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
||||
"""
|
||||
spatial_merge_size = self.spatial_merge_size
|
||||
image_token_id = self.config.image_token_index
|
||||
video_token_id = self.config.video_token_index
|
||||
audio_token_id = self.config.audio_token_index
|
||||
image_token_id = self.config.image_token_id
|
||||
video_token_id = self.config.video_token_id
|
||||
audio_token_id = self.config.audio_token_id
|
||||
vision_start_token_id = self.config.vision_start_token_id
|
||||
audio_start_token_id = self.config.audio_start_token_id
|
||||
position_id_per_seconds = self.config.position_id_per_seconds
|
||||
@ -2450,7 +2450,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_index)
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
@ -2462,7 +2462,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_index)
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
@ -2474,7 +2474,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_index)
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
|
@ -443,6 +443,11 @@ class Qwen2_5OmniThinkerConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_omni_thinker"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
"audio_token_id": "audio_token_index",
|
||||
}
|
||||
sub_configs = {
|
||||
"audio_config": Qwen2_5OmniAudioEncoderConfig,
|
||||
"vision_config": Qwen2_5OmniVisionEncoderConfig,
|
||||
@ -647,6 +652,11 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "qwen2_5_omni_talker"
|
||||
attribute_map = {
|
||||
"image_token_id": "image_token_index",
|
||||
"video_token_id": "video_token_index",
|
||||
"audio_token_id": "audio_token_index",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -1225,9 +1235,9 @@ class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedMo
|
||||
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
||||
"""
|
||||
spatial_merge_size = self.spatial_merge_size
|
||||
image_token_id = self.config.image_token_index
|
||||
video_token_id = self.config.video_token_index
|
||||
audio_token_id = self.config.audio_token_index
|
||||
image_token_id = self.config.image_token_id
|
||||
video_token_id = self.config.video_token_id
|
||||
audio_token_id = self.config.audio_token_id
|
||||
vision_start_token_id = self.config.vision_start_token_id
|
||||
audio_start_token_id = self.config.audio_start_token_id
|
||||
position_id_per_seconds = self.config.position_id_per_seconds
|
||||
@ -2400,7 +2410,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_index)
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
@ -2412,7 +2422,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_index)
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
@ -2424,7 +2434,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_index)
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
|
@ -156,6 +156,10 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
image_token_id (`int`, *optional*):
Token index used as placeholder for image embeddings.
video_token_id (`int`, *optional*):
Token index used as placeholder for video embeddings.

```python
>>> from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLConfig

@ -209,6 +213,8 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
max_window_layers=80,
attention_dropout=0.0,
rope_scaling=None,
image_token_id=None,
video_token_id=None,
**kwargs,
):
self.vocab_size = vocab_size

@ -244,6 +250,8 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})
self.image_token_id = image_token_id
self.video_token_id = video_token_id

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

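Unlike the configs above, the Qwen2.5-VL text config takes no `attribute_map` detour: the new names are accepted as constructor arguments and stored directly. A hedged sketch of setting them follows; only the `Qwen2_5_VLConfig` import is taken from the docstring above, while the nested `text_config` layout and the id values are assumptions made for illustration.

```python
from transformers import Qwen2_5_VLConfig

# Build a default config and point the new placeholder-id attributes at illustrative values.
config = Qwen2_5_VLConfig()
config.text_config.image_token_id = 151655  # illustrative, not necessarily the checkpoint default
config.text_config.video_token_id = 151656

print(config.text_config.image_token_id, config.text_config.video_token_id)
```
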
@ -157,6 +157,9 @@ class Qwen2AudioConfig(PretrainedConfig):
```"""

model_type = "qwen2_audio"
attribute_map = {
"audio_token_id": "audio_token_index",
}
sub_configs = {"text_config": AutoConfig, "audio_config": AutoConfig}

def __init__(

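For Qwen2-Audio the same aliasing applies to the audio placeholder. A small sanity check one could run is sketched below; the id is passed explicitly and is meant as an illustrative value rather than a documented default.

```python
from transformers import Qwen2AudioConfig

# The constructor still takes the legacy keyword; the new name reads the same stored value.
config = Qwen2AudioConfig(audio_token_index=151646)
assert config.audio_token_id == config.audio_token_index == 151646
```
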
@ -932,7 +932,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")

# 1. Create a mask to know where special audio tokens are
special_audio_token_mask = input_ids == self.config.audio_token_index
special_audio_token_mask = input_ids == self.config.audio_token_id
num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)

# In case the Audio model or the Language model has been offloaded to CPU, we need to manually

@ -942,7 +942,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
input_ids = input_ids.to(target_device)
num_audio_tokens = num_audio_tokens.to(target_device)
batch_indices, non_audio_indices = torch.where(
(input_ids != self.config.audio_token_index) & (attention_mask == 1)
(input_ids != self.config.audio_token_id) & (attention_mask == 1)
)

# 2. Compute the positions where text should be written

@ -1114,7 +1114,7 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
audio_features = self.multi_modal_projector(selected_audio_feature)

# if we have consecutive audio tokens, then it means we expanded input_ids in processing
audio_tokens = input_ids == self.config.audio_token_index
audio_tokens = input_ids == self.config.audio_token_id
legacy_processing = (audio_tokens[:, :-1] & audio_tokens[:, 1:]).sum() == 0

if legacy_processing:

@ -1130,14 +1130,14 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
audio_features_mask = audio_features_mask < audio_output_lengths[:, None]
audio_features = audio_features[audio_features_mask]

n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
n_audio_tokens = (input_ids == self.config.audio_token_id).sum().item()
n_audio_features = audio_features.shape[0]

if n_audio_tokens != n_audio_features:
raise ValueError(
f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
)
special_audio_mask = (input_ids == self.config.audio_token_index).to(inputs_embeds.device)
special_audio_mask = (input_ids == self.config.audio_token_id).to(inputs_embeds.device)
special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

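The count check in the hunk above is the usual guard before the scatter: the number of placeholder tokens in `input_ids` must match the number of feature rows coming out of the encoder. A toy, self-contained version follows, with the token id and shapes invented for the example.

```python
import torch

audio_token_id = 151646  # illustrative id
input_ids = torch.tensor([[1, audio_token_id, audio_token_id, audio_token_id, 2]])
audio_features = torch.randn(3, 16)  # one feature row per placeholder token

n_audio_tokens = (input_ids == audio_token_id).sum().item()
n_audio_features = audio_features.shape[0]
if n_audio_tokens != n_audio_features:
    raise ValueError(
        f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
    )
```
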
@ -145,6 +145,10 @@ class Qwen2VLTextConfig(PretrainedConfig):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
image_token_id (`int`, *optional*):
Token index used as placeholder for image embeddings.
video_token_id (`int`, *optional*):
Token index used as placeholder for video embeddings.

```python
>>> from transformers import Qwen2VLTextModel, Qwen2VLConfig

@ -198,6 +202,8 @@ class Qwen2VLTextConfig(PretrainedConfig):
max_window_layers=80,
attention_dropout=0.0,
rope_scaling=None,
image_token_id=None,
video_token_id=None,
**kwargs,
):
self.vocab_size = vocab_size

@ -233,6 +239,8 @@ class Qwen2VLTextConfig(PretrainedConfig):
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})
self.image_token_id = image_token_id
self.video_token_id = video_token_id

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

@ -72,6 +72,11 @@ class ShieldGemma2Config(PretrainedConfig):
```"""

model_type = "shieldgemma2"
attribute_map = {
"image_token_id": "image_token_index",
"boi_token_id": "boi_token_index",
"eoi_token_id": "eoi_token_index",
}
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

def __init__(

@ -80,6 +80,10 @@ class VideoLlavaConfig(PretrainedConfig):
```"""

model_type = "video_llava"
attribute_map = {
"image_token_id": "image_token_index",
"video_token_id": "video_token_index",
}
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

def __init__(

@ -493,10 +493,10 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

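The `is_torchdynamo_compiling()` guard above skips the data-dependent shape check while `torch.compile` is tracing, so the validation only runs in eager mode. A stand-alone sketch of the same pattern follows; the helper name, token id, and shapes are invented for the example, and the guard utility is assumed to be importable from `transformers.utils`.

```python
import torch
from transformers.utils import is_torchdynamo_compiling


def splice_image_features(inputs_embeds, input_ids, image_features, image_token_id=32000):
    # Expand the placeholder mask to embedding shape, validate eagerly, then scatter.
    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    if not is_torchdynamo_compiling() and inputs_embeds[mask].numel() != image_features.numel():
        raise ValueError("Image features and image tokens do not match")
    return inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.device, inputs_embeds.dtype))
```
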
@ -509,10 +509,10 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
)

special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_index).sum()
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_features.shape[0] * video_features.shape[1]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

@ -70,6 +70,9 @@ class VipLlavaConfig(PretrainedConfig):
```"""

model_type = "vipllava"
attribute_map = {
"image_token_id": "image_token_index",
}
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}

def __init__(

@ -374,10 +374,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)

special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@ -224,12 +224,9 @@ class GenerationTesterMixin:
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
if config is not None:
for key in [
"image_token_index",
"image_token_id",
"video_token_index",
"video_token_id",
"vision_start_token_id",
"audio_token_index",
"audio_start_token_id",
"audio_end_token_id",
"vision_end_token_id",

@ -351,8 +351,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"pad_index",
"unk_index",
"mask_index",
"image_token_index",  # for VLMs
"video_token_index",
"image_token_id",  # for VLMs
"video_token_id",
"image_seq_length",
"video_seq_length",
"image_size",