[VLMs] support passing embeds along with pixels (#38467)

* VLMs can work with embeds now

* update more models

* fix tests

* fix copies

* fixup

* fix

* style

* unskip tests

* fix copies

* fix tests

* style

* omni modality models

* qwen models had extra indentation

* fix some other tests

* fix copies

* fix test last time

* unrelated changes revert

* we can't rely only on embeds

* delete file

* de-flake mistral3

* fix qwen models

* fix style

* fix tests

* fix copies

* deflake the test

* modular reverted by fixes, fix again

* flaky test, overwritten

* fix copies

* style
Raushan Turganbay 2025-07-01 13:33:20 +02:00 committed by GitHub
parent 20901f1d68
commit f8b88866f5
78 changed files with 1131 additions and 1705 deletions
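
At a high level, the commit lets `generate()` accept precomputed text embeddings together with raw pixel inputs. A rough usage sketch of that behaviour follows; the checkpoint id, prompt and image URL are illustrative only, and it is assumed here that LLaVA is among the 78 updated models:

# Rough usage sketch (not taken from the diff): generate from precomputed text
# embeddings plus raw pixels. Checkpoint, prompt and image URL are examples only.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # example checkpoint, assumed to be covered by this commit
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)

# Embed the text side ourselves; the image placeholders are still filled from
# `pixel_values` inside the model, because their positions can now be recovered
# from the embeddings alone.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

generated = model.generate(
    inputs_embeds=inputs_embeds,
    pixel_values=inputs["pixel_values"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=30,
)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])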

View File

@@ -733,7 +733,9 @@ class GenerationMixin(ContinuousMixin):
         # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
         # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
         if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
-            if not self.config.is_encoder_decoder:
+            if model_kwargs["inputs_embeds"] is None:
+                model_kwargs.pop("inputs_embeds")
+            elif not self.config.is_encoder_decoder:
                 has_inputs_embeds_forwarding = "inputs_embeds" in set(
                     inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
                 )
@@ -748,10 +750,11 @@ class GenerationMixin(ContinuousMixin):
                 model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                     inputs, bos_token_id, model_kwargs=model_kwargs
                 )
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
-            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
         inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
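
These hunks change how `generate()` picks its primary input: a `None` `inputs_embeds` entry is now simply dropped from `model_kwargs` instead of hijacking input selection, and otherwise decoder-only models promote `inputs_embeds` to the primary input inside the branch. A standalone sketch of that selection rule (a hypothetical, simplified helper, not the actual GenerationMixin method):

# Hypothetical helper reproducing the selection rule above in isolation.
from typing import Any, Dict, Optional, Tuple

import torch


def pick_primary_input(
    input_ids: Optional[torch.Tensor],
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
) -> Tuple[Any, str, Dict[str, Any]]:
    inputs, input_name = input_ids, "input_ids"
    if "inputs_embeds" in model_kwargs:
        if model_kwargs["inputs_embeds"] is None:
            model_kwargs.pop("inputs_embeds")  # an explicit None is ignored
        elif not is_encoder_decoder:
            model_kwargs["input_ids"] = input_ids  # keep ids around for masks etc.
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
        else:
            if inputs is not None:
                raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
    return inputs, input_name, model_kwargs


# Example: a None entry no longer changes the input selection.
ids = torch.tensor([[1, 2, 3]])
inputs, name, kwargs = pick_primary_input(ids, {"inputs_embeds": None})
assert name == "input_ids" and "inputs_embeds" not in kwargs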

View File

@@ -1113,11 +1113,12 @@ class AriaModel(AriaPreTrainedModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                image_embeds = input_ids == self.config.image_token_id
-                special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
                 pixel_mask=pixel_mask,
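
The same placeholder-recovery pattern recurs in most files below: with `input_ids`, the image-token mask is a direct id comparison; with only `inputs_embeds`, the mask is recovered by comparing each position against the embedding of the image token id and collapsing the hidden dimension with `.all(-1)`. A toy, self-contained illustration of that equivalence (vocabulary size, hidden size and token id are made up):

# Toy illustration of the mask-from-embeddings pattern used in these diffs.
import torch
import torch.nn as nn

vocab_size, hidden, image_token_id = 32, 8, 7
embed = nn.Embedding(vocab_size, hidden)

input_ids = torch.tensor([[3, 7, 7, 5]])
inputs_embeds = embed(input_ids)

# Path taken when `input_ids` is None: compare every position against the image-token embedding.
image_token_embed = embed(torch.tensor(image_token_id))
mask_from_embeds = (inputs_embeds == image_token_embed).all(-1)

# Path taken when `input_ids` is available.
mask_from_ids = input_ids == image_token_id

assert torch.equal(mask_from_embeds, mask_from_ids)

# Both variants are then broadcast over the hidden dimension before `masked_scatter`.
expanded = mask_from_ids.unsqueeze(-1).expand_as(inputs_embeds)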

View File

@@ -1446,11 +1446,12 @@ class AriaModel(LlavaModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                image_embeds = input_ids == self.config.image_token_id
-                special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
                 pixel_mask=pixel_mask,

View File

@@ -302,14 +302,14 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -223,14 +223,14 @@ class AyaVisionModel(LlavaModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -1855,6 +1855,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
     _supports_cache_class = True
     _supports_static_cache = True
     _supports_quantized_cache = False  # not all LM bacbones support (e.g. T5)
+
     _keep_in_fp32_modules = ["query_tokens", "qformer"]

     def __init__(self, config: Blip2Config):
@@ -1971,10 +1972,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        input_ids: torch.FloatTensor,
+        input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -2066,14 +2068,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
         language_model_attention_mask = torch.ones(
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )
-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

         # if the model already has "image_token_id" then the input is expanded to account for image embeds
-        # otherwise we expand manually by concating
+        # otherwise we expand manually by concatenating
         if getattr(self.config, "image_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
@@ -2146,6 +2159,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
         pixel_values: torch.FloatTensor,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         interpolate_pos_encoding: bool = False,
         **generate_kwargs,
     ) -> torch.LongTensor:
@@ -2159,6 +2173,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
                 The sequence used as a prompt for the generation.
             attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Embedded representation of the inputs. Should be float, not int tokens.
+            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+                Whether to interpolate the positional encoding of the image embeddings.

         Returns:
             captions (list): A list of strings of length batch_size * num_captions.
@@ -2193,22 +2211,32 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        if input_ids is None:
-            start_tokens = [self.config.text_config.bos_token_id]
-            if getattr(self.config, "image_token_id", None) is not None:
-                start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
-            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
-            input_ids = input_ids.repeat(batch_size, 1)
-
-        inputs_embeds = self.get_input_embeddings()(input_ids)
+        if inputs_embeds is None:
+            if input_ids is None:
+                start_tokens = [self.config.text_config.bos_token_id]
+                if getattr(self.config, "image_token_id", None) is not None:
+                    start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

         # if the model already has "image_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "image_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in BLIP-2 should be done in processing. "

View File

@@ -963,25 +963,28 @@ class ChameleonModel(ChameleonPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)

         if pixel_values is not None:
-            image_tokens = self.get_image_tokens(pixel_values)
-            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
-                n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
-                n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+
+            n_image_tokens_in_text = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            image_embeds = self.get_image_features(pixel_values)
+            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel():
+                n_image_features = image_embeds.shape[0] * image_embeds.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
                 )
-            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
-            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

         # torch.jit.trace() doesn't support cache objects in the output
         if use_cache and past_key_values is None and not torch.jit.is_tracing():
View File

@@ -1537,20 +1537,26 @@ class Emu3Model(Emu3PreTrainedModel):
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
-            image_tokens = self.get_image_tokens(pixel_values, image_sizes)
-            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
-            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
+            image_embeds = self.get_image_features(pixel_values, image_sizes)
+            image_embeds = torch.cat(image_embeds, dim=0)
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.text_model(
-            input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,

View File

@@ -1033,20 +1033,26 @@ class Emu3Model(Emu3PreTrainedModel):
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
-            image_tokens = self.get_image_tokens(pixel_values, image_sizes)
-            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
-            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
+            image_embeds = self.get_image_features(pixel_values, image_sizes)
+            image_embeds = torch.cat(image_embeds, dim=0)
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.text_model(
-            input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,

View File

@@ -206,14 +206,22 @@ class FuyuModel(FuyuPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

-        if image_patches is not None and past_key_values is None:
-            patch_embeddings = self.get_image_features(image_patches)
-            patch_embeddings = torch.cat(patch_embeddings, dim=0)
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-            patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
+        if image_patches is not None:
+            patch_embeddings = self.get_image_features(image_patches)
+            patch_embeddings = torch.cat(patch_embeddings, dim=0)
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)

         outputs = self.language_model(
             inputs_embeds=inputs_embeds,

View File

@@ -898,9 +898,11 @@ class Gemma3Model(Gemma3PreTrainedModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                special_image_mask = input_ids == self.config.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                 image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]

View File

@@ -800,9 +800,11 @@ class Gemma3Model(PaliGemmaModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                special_image_mask = input_ids == self.config.image_token_id
+
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                 image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]

View File

@@ -1237,50 +1237,59 @@ class Glm4vModel(Glm4vPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
             image_embeds = self.get_image_features(pixel_values, image_grid_thw)
             image_embeds = torch.cat(image_embeds, dim=0)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum()
+
+            if input_ids is None:
+                image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_mask = image_mask.all(-1)
+            else:
+                image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = image_mask.sum()
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             n_image_features = image_embeds.shape[0]
             if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

         if pixel_values_videos is not None:
             video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
             video_embeds = torch.cat(video_embeds, dim=0)
-            n_video_tokens = (input_ids == self.config.image_token_id).sum()
+
+            if input_ids is None:
+                video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                video_mask = video_mask.all(-1)
+            else:
+                video_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (video_mask).sum()
             n_video_features = video_embeds.shape[0]
+            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
+            # GLM-4.1V use image_token_id for video
             video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

         if position_ids is None:
-            attention_mask_tensor = attention_mask
+            attention_mask_tensor = (
+                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+            )
             if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                 attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                 attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
@@ -1571,6 +1580,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
     def _get_image_nums_and_video_nums(
         self,
         input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1585,9 +1595,29 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
             video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
         """
-        is_image = input_ids == self.config.image_start_token_id
-        is_video_start = input_ids == self.config.video_start_token_id
-        is_video_end = input_ids == self.config.video_end_token_id
+        if inputs_embeds is not None:
+            is_image = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            is_video_start = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            is_video_end = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+        else:
+            is_image = input_ids == self.config.image_start_token_id
+            is_video_start = input_ids == self.config.video_start_token_id
+            is_video_end = input_ids == self.config.video_end_token_id

         # Cumulative sum to track if we're inside a video span
         # We'll assume well-formed video tags (i.e. matching starts and ends)
@@ -1623,7 +1653,9 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
         def _expand_dict_for_generation_visual(dict_to_expand):
             image_grid_thw = model_kwargs.get("image_grid_thw", None)
             video_grid_thw = model_kwargs.get("video_grid_thw", None)
-            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(
+                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+            )

             def _repeat_interleave_samples(x, lengths, repeat_times):
                 samples = torch.split(x, lengths)
@@ -1679,10 +1711,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
                 dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
             return dict_to_expand

-        # input_ids is required for expanding visual inputs
-        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
-        if input_ids is not None and input_ids.numel() != 0:
-            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
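
`_get_image_nums_and_video_nums` can now derive its start/end marker masks from `inputs_embeds` as well, and the downstream counting (only hinted at by the comments above) walks those masks with a cumulative sum to know which positions sit inside a video span. A hypothetical, simplified version of that counting, with invented marker layouts:

# Hypothetical, simplified span counting: a cumulative sum over start/end markers
# tells us which positions are inside a video, so image-start markers that belong
# to a video are not double counted. Sequences below are invented.
import torch


def count_images_and_videos(is_image, is_video_start, is_video_end):
    inside_video = (is_video_start.int().cumsum(-1) - is_video_end.int().cumsum(-1)) > 0
    image_nums = (is_image & ~inside_video).sum(-1)
    video_nums = is_video_start.sum(-1)
    return image_nums, video_nums


# one sample: <image> ... <video_start> <image> <video_end> ... <image>
is_image       = torch.tensor([[1, 0, 0, 1, 0, 0, 1]], dtype=torch.bool)
is_video_start = torch.tensor([[0, 0, 1, 0, 0, 0, 0]], dtype=torch.bool)
is_video_end   = torch.tensor([[0, 0, 0, 0, 1, 0, 0]], dtype=torch.bool)
print(count_images_and_videos(is_image, is_video_start, is_video_end))  # tensor([2]), tensor([1])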

View File

@@ -1237,50 +1237,59 @@ class Glm4vModel(Qwen2_5_VLModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
             image_embeds = self.get_image_features(pixel_values, image_grid_thw)
             image_embeds = torch.cat(image_embeds, dim=0)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum()
+
+            if input_ids is None:
+                image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_mask = image_mask.all(-1)
+            else:
+                image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = image_mask.sum()
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             n_image_features = image_embeds.shape[0]
             if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

         if pixel_values_videos is not None:
             video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
             video_embeds = torch.cat(video_embeds, dim=0)
-            n_video_tokens = (input_ids == self.config.image_token_id).sum()
+
+            if input_ids is None:
+                video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                video_mask = video_mask.all(-1)
+            else:
+                video_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (video_mask).sum()
             n_video_features = video_embeds.shape[0]
+            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
+            # GLM-4.1V use image_token_id for video
             video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

         if position_ids is None:
-            attention_mask_tensor = attention_mask
+            attention_mask_tensor = (
+                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+            )
             if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                 attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                 attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
@@ -1500,6 +1509,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
     def _get_image_nums_and_video_nums(
         self,
         input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1514,9 +1524,29 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
             video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
         """
-        is_image = input_ids == self.config.image_start_token_id
-        is_video_start = input_ids == self.config.video_start_token_id
-        is_video_end = input_ids == self.config.video_end_token_id
+        if inputs_embeds is not None:
+            is_image = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            is_video_start = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            is_video_end = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+        else:
+            is_image = input_ids == self.config.image_start_token_id
+            is_video_start = input_ids == self.config.video_start_token_id
+            is_video_end = input_ids == self.config.video_end_token_id

         # Cumulative sum to track if we're inside a video span
         # We'll assume well-formed video tags (i.e. matching starts and ends)
# We'll assume well-formed video tags (i.e. matching starts and ends) # We'll assume well-formed video tags (i.e. matching starts and ends)

View File

@@ -648,24 +648,27 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
-            n_image_tokens = (input_ids == self.config.image_token_id).sum()
             n_image_features = image_features.shape[0] * image_features.shape[1]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

View File

@@ -339,24 +339,27 @@ class GotOcr2Model(LlavaModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
-            n_image_tokens = (input_ids == self.config.image_token_id).sum()
             n_image_features = image_features.shape[0] * image_features.shape[1]
             if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

View File

@@ -933,10 +933,18 @@ class Idefics2Model(Idefics2PreTrainedModel):
        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
        """
-        special_image_token_mask = input_ids == self.image_token_id
-        new_inputs_embeds = inputs_embeds.clone()
-        new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device)
-        return new_inputs_embeds
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
+        return inputs_embeds

     def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
         """
@@ -1041,25 +1049,8 @@ class Idefics2Model(Idefics2PreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        past_seen_tokens = 0
-        # kept for BC (non `Cache` `past_key_values` inputs)
-        return_legacy_cache = False
-        if use_cache:
-            if not isinstance(past_key_values, Cache):
-                return_legacy_cache = True
-                if past_key_values is None:
-                    past_key_values = DynamicCache()
-                else:
-                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                    logger.warning_once(
-                        "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-                        "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-                        "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-                    )
-            past_seen_tokens = past_key_values.get_seq_length()
-
-        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
-            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = DynamicCache()

         if inputs_embeds is None:
             inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
@@ -1072,7 +1063,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

-        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
+        if image_hidden_states is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             inputs_embeds = self.inputs_merger(
@@ -1094,9 +1085,6 @@ class Idefics2Model(Idefics2PreTrainedModel):
             **kwargs,
         )

-        if return_legacy_cache and use_cache:
-            outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
-
         return Idefics2BaseModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values,
@@ -1304,37 +1292,11 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        # but IDEFICS requires both ids and embeds to be present
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs["input_ids"] = input_ids
-
-        if image_hidden_states is not None:
+        if image_hidden_states is not None or cache_position[0] != 0:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_attention_mask"] = None

         return model_inputs

-    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
-        model_kwargs = super()._update_model_kwargs_for_generation(
-            outputs=outputs,
-            model_kwargs=model_kwargs,
-            is_encoder_decoder=is_encoder_decoder,
-            **kwargs,
-        )
-        # Get the precomputed image_hidden_states
-        model_kwargs["image_hidden_states"] = outputs.image_hidden_states
-        return model_kwargs
-
-    @staticmethod
-    # Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-

 __all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]

View File

@ -663,15 +663,18 @@ class Idefics3Model(Idefics3PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
""" """
special_image_token_mask = input_ids == self.image_token_id if input_ids is None:
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. special_image_mask = inputs_embeds == self.get_input_embeddings()(
new_inputs_embeds = inputs_embeds.clone() torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
# Flatten `image_hidden_states` if not flat yet )
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) special_image_mask = special_image_mask.all(-1)
# cast to the dtype of the input_embeds to support quantized models else:
special_image_mask = input_ids == self.config.image_token_id
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = image_hidden_states inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
return new_inputs_embeds return inputs_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
""" """
@@ -773,11 +776,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        past_seen_tokens = 0
-        if use_cache:
-            if past_key_values is None:
-                past_key_values = DynamicCache()
-            past_seen_tokens = past_key_values.get_seq_length()
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()

         if inputs_embeds is None:
             inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

@@ -790,7 +790,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

-        if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None:
+        if image_hidden_states is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             inputs_embeds = self.inputs_merger(

@@ -1042,28 +1042,11 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        # but IDEFICS requires both ids and embeds to be present
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs["input_ids"] = input_ids
-
-        if image_hidden_states is not None:
+        if image_hidden_states is not None or cache_position[0] != 0:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_attention_mask"] = None

         return model_inputs

-    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
-    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
-        model_kwargs = super()._update_model_kwargs_for_generation(
-            outputs=outputs,
-            model_kwargs=model_kwargs,
-            is_encoder_decoder=is_encoder_decoder,
-            **kwargs,
-        )
-        # Get the precomputed image_hidden_states
-        model_kwargs["image_hidden_states"] = outputs.image_hidden_states
-        return model_kwargs
-

 __all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]
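A hypothetical, stripped-down version of the new pixel gating in prepare_inputs_for_generation (the helper name and dummy values are illustrative, not part of the diff): once the prefill step is done, or when precomputed image_hidden_states are supplied, the pixel inputs are dropped so freshly generated image-token ids are not re-filled with images that do not exist.

import torch

def drop_stale_pixel_inputs(model_inputs, image_hidden_states, cache_position):
    # Mirrors the condition above: keep pixels only for the very first step without precomputed features.
    if image_hidden_states is not None or cache_position[0] != 0:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_attention_mask"] = None
    return model_inputs

inputs = {"pixel_values": torch.zeros(1, 3, 4, 4), "pixel_attention_mask": torch.ones(1, 4, 4)}
print(drop_stale_pixel_inputs(inputs, None, torch.tensor([3])))  # pixels are cleared after step 0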

View File

@@ -1255,6 +1255,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,

@@ -1328,12 +1329,20 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
         # step 3: use the language model, conditioned on the query outputs and the prompt
         language_model_inputs = self.language_projection(query_output)

-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+            special_image_mask = input_ids == self.config.image_token_id
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+        else:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)

-        special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-        inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(

@@ -1513,6 +1522,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -1604,15 +1614,26 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

         # if the model already has "image_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "image_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in InstructBLIP should be done in processing. "

@@ -1673,6 +1694,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
         qformer_attention_mask: Optional[torch.LongTensor] = None,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         interpolate_pos_encoding: bool = False,
         **generate_kwargs,
     ) -> torch.LongTensor:

@@ -1690,6 +1712,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
                 The sequence used as a prompt for the generation.
             attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                 Mask to avoid performing attention on padding token indices.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Embedded representation of the inputs. Should be float, not int tokens.
             interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                 Whether to interpolate the positional encoding of the image embeddings.

@@ -1712,23 +1736,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        if input_ids is None:
-            start_tokens = [self.config.text_config.bos_token_id]
-            if getattr(self.config, "image_token_id", None) is not None:
-                start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
-            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
-            input_ids = input_ids.repeat(batch_size, 1)
+        if inputs_embeds is None:
+            if input_ids is None:
+                start_tokens = [self.config.text_config.bos_token_id]
+                if getattr(self.config, "image_token_id", None) is not None:
+                    start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)

         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

-        inputs_embeds = self.get_input_embeddings()(input_ids)
-
         # if the model already has "image_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "image_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in InstructBLIP should be done in processing. "

View File

@@ -1251,6 +1251,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,

@@ -1334,12 +1335,20 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
         # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
         language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+            special_image_mask = input_ids == self.config.video_token_id
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+        else:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)

-        special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-        inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(

@@ -1485,6 +1494,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -1599,15 +1609,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

         # if the model already has "video_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "video_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

@@ -1668,6 +1689,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
         qformer_attention_mask: Optional[torch.LongTensor] = None,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         interpolate_pos_encoding: bool = False,
         **generate_kwargs,
     ) -> torch.LongTensor:

@@ -1685,6 +1707,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
                 The sequence used as a prompt for the generation.
             attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                 Mask to avoid performing attention on padding token indices.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Embedded representation of the inputs. Should be float, not int tokens.
             interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                 Whether to interpolate the positional encoding of the image embeddings.

@@ -1708,23 +1732,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        if input_ids is None:
-            start_tokens = [self.config.text_config.bos_token_id]
-            if getattr(self.config, "video_token_id", None) is not None:
-                start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
-            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
-            input_ids = input_ids.repeat(batch_size, 1)
+        if inputs_embeds is None:
+            if input_ids is None:
+                start_tokens = [self.config.text_config.bos_token_id]
+                if getattr(self.config, "video_token_id", None) is not None:
+                    start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)

         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

-        inputs_embeds = self.get_input_embeddings()(input_ids)
-
         # if the model already has "video_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "video_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

View File

@@ -202,6 +202,7 @@ class InstructBlipVideoModel(InstructBlipModel):
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,

@@ -255,12 +256,20 @@ class InstructBlipVideoModel(InstructBlipModel):
         # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
         language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+            special_image_mask = input_ids == self.config.video_token_id
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+        else:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)

-        special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-        inputs_embeds[special_image_mask] = language_model_inputs.flatten()
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

         if self.config.use_decoder_only_language_model:
             outputs = self.language_model(

@@ -372,6 +381,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
         attention_mask: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,

@@ -451,15 +461,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

         # if the model already has "video_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "video_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

@@ -520,6 +541,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
         qformer_attention_mask: Optional[torch.LongTensor] = None,
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         interpolate_pos_encoding: bool = False,
         **generate_kwargs,
     ) -> torch.LongTensor:

@@ -537,6 +559,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
                 The sequence used as a prompt for the generation.
             attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                 Mask to avoid performing attention on padding token indices.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Embedded representation of the inputs. Should be float, not int tokens.
             interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                 Whether to interpolate the positional encoding of the image embeddings.

@@ -560,23 +584,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
         )

-        if input_ids is None:
-            start_tokens = [self.config.text_config.bos_token_id]
-            if getattr(self.config, "video_token_id", None) is not None:
-                start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
-            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
-            input_ids = input_ids.repeat(batch_size, 1)
+        if inputs_embeds is None:
+            if input_ids is None:
+                start_tokens = [self.config.text_config.bos_token_id]
+                if getattr(self.config, "video_token_id", None) is not None:
+                    start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)

         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

-        inputs_embeds = self.get_input_embeddings()(input_ids)
-
         # if the model already has "video_token_id" then the input is expanded to account for image embeds
         # otherwise we expand manually by concatenating
         if getattr(self.config, "video_token_id", None) is not None:
-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
-            inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
         else:
             logger.warning_once(
                 "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

View File

@@ -710,14 +710,14 @@ class InternVLModel(InternVLPreTrainedModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -641,14 +641,14 @@ class InternVLModel(LlavaModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -1102,23 +1102,21 @@ class JanusModel(JanusPreTrainedModel):
             )
             use_cache = False

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
+            if input_ids is None:
+                image_attention_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_attention_mask = image_attention_mask.all(-1)
+            else:
+                image_attention_mask = input_ids == self.config.image_token_id
+
+            image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_embeds = self.get_image_features(pixel_values)
-            image_attention_mask = input_ids == self.config.image_token_id
-
-            embed_dim = inputs_embeds.shape[-1]
-            image_features = image_embeds.reshape(-1, embed_dim)
-            image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
-            image_attention_mask = image_attention_mask.to(inputs_embeds.device)
+            image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
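Janus now flattens the vision output straight to `(num_image_tokens, hidden_size)` and lets `masked_scatter` place the rows, instead of manually expanding the mask with an explicit `embed_dim`. A toy sketch with invented shapes:

import torch

hidden = 8
inputs_embeds = torch.randn(1, 10, hidden)
image_attention_mask = torch.zeros(1, 10, dtype=torch.bool)
image_attention_mask[0, 2:8] = True                      # 6 placeholder slots
image_embeds = torch.randn(2, 3, hidden)                 # e.g. 2 crops x 3 tokens (made-up layout)

image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])   # -> (6, hidden)
mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)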

View File

@@ -955,23 +955,21 @@ class JanusModel(JanusPreTrainedModel):
             )
             use_cache = False

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
+            if input_ids is None:
+                image_attention_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_attention_mask = image_attention_mask.all(-1)
+            else:
+                image_attention_mask = input_ids == self.config.image_token_id
+
+            image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             image_embeds = self.get_image_features(pixel_values)
-            image_attention_mask = input_ids == self.config.image_token_id
-
-            embed_dim = inputs_embeds.shape[-1]
-            image_features = image_embeds.reshape(-1, embed_dim)
-            image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
-            image_attention_mask = image_attention_mask.to(inputs_embeds.device)
+            image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)

View File

@@ -1467,25 +1467,19 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
         image_embeds_position_mask=None,
         past_key_values=None,
         attention_mask=None,
+        inputs_embeds=None,
         use_cache=None,
         cache_position=None,
         **model_kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        # Kosmos2 has offset for position ids, so we need to create them correctly
-        position_ids = create_position_ids_from_input_ids(
-            input_ids,
-            padding_idx=self.config.pad_token_id,
-            past_key_values_length=0,
-        )
-
         if past_key_values is not None:
             image_embeds = None
             image_embeds_position_mask = None
         # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
         elif image_embeds_position_mask is not None:
-            batch_size, seq_len = input_ids.size()
+            batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()
             mask_len = image_embeds_position_mask.size()[-1]
             image_embeds_position_mask = torch.cat(
                 (

@@ -1501,11 +1495,13 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             image_embeds=image_embeds,
             image_embeds_position_mask=image_embeds_position_mask,
+            inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            position_ids=position_ids,
             cache_position=cache_position,
             **model_kwargs,
         )
+        # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
+        model_inputs.pop("position_ids", None)
         return model_inputs

@@ -1875,6 +1871,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_embeds: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         # in order to allow `inputs` argument (as in `GenerationMixin`)

@@ -1900,6 +1897,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             image_embeds=image_embeds,
             image_embeds_position_mask=image_embeds_position_mask,
+            inputs_embeds=inputs_embeds,
             **kwargs,
         )
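Kosmos-2 keeps a per-position mask telling the text model which slots hold image embeddings; during generation the sequence grows while the image stays where it was, so the mask is right-padded with `False` to the current length, which is now measured from `inputs_embeds` when ids are absent. A standalone sketch of that padding step, with invented sizes:

import torch

batch_size, seq_len = 2, 7                     # current sequence length (from ids or embeds)
image_embeds_position_mask = torch.tensor([[True, True, True, False, False]] * batch_size)
mask_len = image_embeds_position_mask.size(-1)
image_embeds_position_mask = torch.cat(
    (
        image_embeds_position_mask,
        torch.zeros(batch_size, seq_len - mask_len, dtype=torch.bool),
    ),
    dim=1,
)
assert image_embeds_position_mask.shape == (batch_size, seq_len)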

View File

@@ -1358,27 +1358,28 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
                 vision_feature_select_strategy=vision_feature_select_strategy,
                 image_sizes=image_sizes,
             )
-            original_inputs_embeds_shape = inputs_embeds.shape
-
             vision_flat = image_features.view(-1, image_features.size(-1))
             projected_vision_flat = self.multi_modal_projector(vision_flat)

-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            final_mask = special_image_mask.to(inputs_embeds.device)
-            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id

-            final_mask_1d = final_mask[..., 0].reshape(-1)
-            num_tokens_to_fill = final_mask_1d.sum()
-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            if n_image_tokens != projected_vision_flat.size(0):
                 raise ValueError(
-                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+                    f"Mismatch: final_mask wants {n_image_tokens} embeddings, "
                     f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
                 )

-            expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
-            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
-            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+            projected_vision_flat = projected_vision_flat.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)

             outputs = self.language_model(
                 attention_mask=attention_mask,
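Llama 4 previously flattened `inputs_embeds` to 2-D, scattered, and reshaped back; the rewrite validates the placeholder count against the projected vision sequence and scatters directly on the 3-D tensor. A simplified sketch with invented shapes:

import torch

inputs_embeds = torch.randn(1, 10, 16)
special_image_mask = torch.zeros(1, 10, dtype=torch.bool)
special_image_mask[0, 2:6] = True                    # 4 placeholder slots
projected_vision_flat = torch.randn(4, 16)           # 4 projected patch embeddings

n_image_tokens = special_image_mask.sum()
if n_image_tokens != projected_vision_flat.size(0):
    raise ValueError("Mismatch between placeholder tokens and projected vision features")

expanded_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)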

View File

@ -284,14 +284,14 @@ class LlavaModel(LlavaPreTrainedModel):
@@ -284,14 +284,14 @@ class LlavaModel(LlavaPreTrainedModel):
                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
                     torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                 )
-                n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                special_image_mask = special_image_mask.all(-1)
             else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -468,11 +468,6 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -485,10 +480,18 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
             )
             image_features = torch.cat(image_features, dim=0)

-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

View File

@@ -519,12 +519,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
-                "and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -537,10 +531,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
             )
             image_features = torch.cat(image_features, dim=0)

-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -559,10 +561,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
             video_features = torch.cat(video_features, dim=0)
             video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)

-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                 n_video_features = video_features.shape[0]
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

View File

@@ -440,12 +440,6 @@ class LlavaNextVideoModel(LlavaNextModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
-                "and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -458,10 +452,18 @@ class LlavaNextVideoModel(LlavaNextModel):
             )
             image_features = torch.cat(image_features, dim=0)

-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -480,10 +482,18 @@ class LlavaNextVideoModel(LlavaNextModel):
             video_features = torch.cat(video_features, dim=0)
             video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)

-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                 n_video_features = video_features.shape[0]
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

View File

@@ -551,12 +551,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
-                "and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -571,10 +565,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
             )
             image_features = torch.cat(image_features, dim=0)

-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -595,10 +597,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)

-            special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
-            special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_video_mask = special_video_mask.all(-1)
+            else:
+                special_video_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (special_video_mask).sum()
+            special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_id).sum()
                 n_video_features = video_features.shape[0]
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"

View File

@@ -535,12 +535,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

-        if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
-                "and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -555,10 +549,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
             )
             image_features = torch.cat(image_features, dim=0)

-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -579,10 +581,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)

-            special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
-            special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_video_mask = special_video_mask.all(-1)
+            else:
+                special_video_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (special_video_mask).sum()
+            special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_id).sum()
                 n_video_features = video_features.shape[0]
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
@@ -308,11 +308,6 @@ class Mistral3Model(Mistral3PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -324,10 +319,18 @@ class Mistral3Model(Mistral3PreTrainedModel):
             )
             image_features = torch.cat(image_features, dim=0)
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -204,11 +204,6 @@ class Mistral3Model(LlavaModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -220,10 +215,18 @@ class Mistral3Model(LlavaModel):
             )
             image_features = torch.cat(image_features, dim=0)
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -331,9 +331,11 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
             special_image_mask = inputs_embeds == self.get_input_embeddings()(
                 torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
             )
+            special_image_mask = special_image_mask.all(-1)
         else:
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
         if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
             image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
@@ -1903,43 +1903,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
         # 2. Merge text , audios , image and video
-        if input_ids is not None and input_ids.shape[1] != 1:  # Prefill stage
-            if input_features is not None:
-                audio_features = self.get_audio_features(
-                    input_features,
-                    feature_attention_mask=feature_attention_mask,
-                    audio_feature_lengths=audio_feature_lengths,
-                )
-                audio_mask = (
-                    (input_ids == self.config.audio_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
-
-            if pixel_values is not None:
-                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
-                image_mask = (
-                    (input_ids == self.config.image_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-            if pixel_values_videos is not None:
-                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
-                video_mask = (
-                    (input_ids == self.config.video_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        if input_features is not None:
+            audio_features = self.get_audio_features(
+                input_features,
+                feature_attention_mask=feature_attention_mask,
+                audio_feature_lengths=audio_feature_lengths,
+            )
+            if input_ids is None:
+                audio_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                audio_mask = audio_mask.all(-1)
+            else:
+                audio_mask = input_ids == self.config.audio_token_id
+
+            audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
+
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+            if input_ids is None:
+                image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_mask = image_mask.all(-1)
+            else:
+                image_mask = input_ids == self.config.image_token_id
+
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+        if pixel_values_videos is not None:
+            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+            if input_ids is None:
+                video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                video_mask = video_mask.all(-1)
+            else:
+                video_mask = input_ids == self.config.video_token_id
+
+            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
 
         if feature_attention_mask is not None:
             audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
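Each modality branch above ends with the same merge step: the projected audio, image, or video features are written into the text embedding stream at the placeholder positions via `masked_scatter`. A self-contained sketch of that merge with toy shapes (not the model's actual code):

import torch

batch, seq, hidden = 1, 6, 4
inputs_embeds = torch.zeros(batch, seq, hidden)

# Pretend positions 2..4 are <image> placeholders and we have 3 projected patch features.
image_mask = torch.tensor([[False, False, True, True, True, False]])
image_embeds = torch.arange(3 * hidden, dtype=inputs_embeds.dtype).reshape(3, hidden)

# Broadcast the mask over the hidden dim, then scatter the features in row order.
expanded_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, image_embeds)

print(inputs_embeds[0, 2])  # tensor([0., 1., 2., 3.]) -> first patch landed on the first placeholder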
@@ -2350,43 +2350,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration):
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
         # 2. Merge text , audios , image and video
-        if input_ids is not None and input_ids.shape[1] != 1:  # Prefill stage
-            if input_features is not None:
-                audio_features = self.get_audio_features(
-                    input_features,
-                    feature_attention_mask=feature_attention_mask,
-                    audio_feature_lengths=audio_feature_lengths,
-                )
-                audio_mask = (
-                    (input_ids == self.config.audio_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
-
-            if pixel_values is not None:
-                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
-                image_mask = (
-                    (input_ids == self.config.image_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-            if pixel_values_videos is not None:
-                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
-                video_mask = (
-                    (input_ids == self.config.video_token_id)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        if input_features is not None:
+            audio_features = self.get_audio_features(
+                input_features,
+                feature_attention_mask=feature_attention_mask,
+                audio_feature_lengths=audio_feature_lengths,
+            )
+            if input_ids is None:
+                audio_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                audio_mask = audio_mask.all(-1)
+            else:
+                audio_mask = input_ids == self.config.audio_token_id
+
+            audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
+
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+            if input_ids is None:
+                image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_mask = image_mask.all(-1)
+            else:
+                image_mask = input_ids == self.config.image_token_id
+
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+        if pixel_values_videos is not None:
+            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+            if input_ids is None:
+                video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                video_mask = video_mask.all(-1)
+            else:
+                video_mask = input_ids == self.config.video_token_id
+
+            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
 
         if feature_attention_mask is not None:
             audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
@@ -1245,41 +1245,51 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
-        if pixel_values is not None:
-            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
-            image_embeds = torch.cat(image_embeds, dim=0)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum()
-            n_image_features = image_embeds.shape[0]
-            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
-            video_embeds = torch.cat(video_embeds, dim=0)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum()
-            n_video_features = video_embeds.shape[0]
-            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+            image_embeds = torch.cat(image_embeds, dim=0)
+
+            if input_ids is None:
+                image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_mask = image_mask.all(-1)
+            else:
+                image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (image_mask).sum()
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            n_image_features = image_embeds.shape[0]
+            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+        if pixel_values_videos is not None:
+            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+            video_embeds = torch.cat(video_embeds, dim=0)
+
+            if input_ids is None:
+                video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                video_mask = video_mask.all(-1)
+            else:
+                video_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (video_mask).sum()
+            n_video_features = video_embeds.shape[0]
+            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
 
         if position_ids is None:
             attention_mask_tensor = (
@@ -1586,6 +1596,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
     def _get_image_nums_and_video_nums(
         self,
         input_ids: Optional[torch.LongTensor],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1603,10 +1614,31 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         video_token_id = self.config.video_token_id
         vision_start_token_id = self.config.vision_start_token_id
-        vision_start_mask = input_ids == vision_start_token_id
+        if inputs_embeds is not None:
+            vision_start_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            image_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+            video_mask = (
+                inputs_embeds
+                == self.get_input_embeddings()(
+                    torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            )[..., 0]
+        else:
+            vision_start_mask = input_ids == vision_start_token_id
+            image_mask = input_ids == image_token_id
+            video_mask = input_ids == video_token_id
+
         vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
-        image_mask = input_ids == image_token_id
-        video_mask = input_ids == video_token_id
         image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
         video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
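The counting above only needs boolean masks, so it works the same whether the masks come from `input_ids` or from `inputs_embeds`: an image (or video) is counted when its placeholder immediately follows a `vision_start` token, which `torch.roll` lines up. A standalone sketch with made-up token ids:

import torch

vision_start_token_id, image_token_id, video_token_id = 101, 102, 103  # made-up ids
input_ids = torch.tensor([[1, 101, 102, 102, 5, 101, 103, 103, 2]])  # one image, one video

vision_start_mask = input_ids == vision_start_token_id
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id

# Shift the start mask right by one: True now marks "the token right after a vision_start".
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

print(image_nums, video_nums)  # tensor([1]) tensor([1])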
@@ -1632,7 +1664,9 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
         def _expand_dict_for_generation_visual(dict_to_expand):
             image_grid_thw = model_kwargs.get("image_grid_thw", None)
             video_grid_thw = model_kwargs.get("video_grid_thw", None)
-            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(
+                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
+            )
 
             def _repeat_interleave_samples(x, lengths, repeat_times):
                 samples = torch.split(x, lengths)
@@ -1688,10 +1722,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
                     dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
             return dict_to_expand
 
-        # input_ids is required for expanding visual inputs
-        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
-        if input_ids is not None and input_ids.numel() != 0:
-            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
 
         if input_ids is not None:
             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
@@ -609,41 +609,51 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
mask = input_ids == self.config.image_token_id if pixel_values is not None:
mask_unsqueezed = mask.unsqueeze(-1) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) image_embeds = torch.cat(image_embeds, dim=0)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) if input_ids is None:
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id
if pixel_values_videos is not None: n_image_tokens = (image_mask).sum()
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
video_embeds = torch.cat(video_embeds, dim=0) n_image_features = image_embeds.shape[0]
n_video_tokens = (input_ids == self.config.video_token_id).sum() if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
n_video_features = video_embeds.shape[0] raise ValueError(
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
raise ValueError( )
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
mask = input_ids == self.config.video_token_id if pixel_values_videos is not None:
mask_unsqueezed = mask.unsqueeze(-1) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) video_embeds = torch.cat(video_embeds, dim=0)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) if input_ids is None:
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id
n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None: if position_ids is None:
attention_mask_tensor = ( attention_mask_tensor = (
@@ -1182,41 +1182,52 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
-        if pixel_values is not None:
-            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
-            image_embeds = torch.cat(image_embeds, dim=0)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
-            video_embeds = torch.cat(video_embeds, dim=0)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+            image_embeds = torch.cat(image_embeds, dim=0)
+
+            if input_ids is None:
+                image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                image_mask = image_mask.all(-1)
+            else:
+                image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = image_mask.sum()
+            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            n_image_features = image_embeds.shape[0]
+            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+        if pixel_values_videos is not None:
+            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+            video_embeds = torch.cat(video_embeds, dim=0)
+
+            if input_ids is None:
+                video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                n_video_tokens = (video_mask).sum(dim=1).sum(dim=0)[0]
+            else:
+                video_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+                video_mask = video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                n_video_tokens = (input_ids == self.config.image_token_id).sum()
+
+            n_video_features = video_embeds.shape[0]
+            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
 
         if position_ids is None:
             attention_mask_tensor = (
@@ -1480,6 +1491,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
def _get_image_nums_and_video_nums( def _get_image_nums_and_video_nums(
self, self,
input_ids: Optional[torch.LongTensor], input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Get the number of images and videos for each sample to calculate the separation length of the sample tensor. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1497,10 +1509,31 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
video_token_id = self.config.video_token_id video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id vision_start_token_id = self.config.vision_start_token_id
vision_start_mask = input_ids == vision_start_token_id if inputs_embeds is not None:
vision_start_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
image_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
video_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
vision_start_mask = input_ids == vision_start_token_id
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id
image_nums = torch.sum(vision_first_mask & image_mask, dim=1) image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
@@ -1526,7 +1559,9 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
def _expand_dict_for_generation_visual(dict_to_expand): def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None) image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)
def _repeat_interleave_samples(x, lengths, repeat_times): def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths) samples = torch.split(x, lengths)
@@ -1582,10 +1617,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand return dict_to_expand
# input_ids is required for expanding visual inputs model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None: if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0) input_ids = input_ids.repeat_interleave(expand_size, dim=0)
@@ -595,7 +595,14 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
         """
         _, patch_size, _ = image_hidden_states.shape
-        image_mask = input_ids == self.image_token_id
+        if input_ids is None:
+            image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            image_mask = image_mask[..., 0]  # slice off the hidden dim
+        else:
+            image_mask = input_ids == self.config.image_token_id
+
         num_image_tokens = image_mask.sum(dim=1)
         if not torch.all(num_image_tokens % patch_size == 0):
             raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
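SmolVLM reduces the embedding-equality comparison with `[..., 0]` instead of `.all(-1)`: positions that hold the image-token embedding match in every component, so keeping only the first component is a cheaper way to get the same boolean mask (assuming no other embedding happens to share that first component). A toy sketch of the mask plus the `patch_size` divisibility check, with made-up sizes:

import torch
import torch.nn as nn

vocab, hidden, patch_size, image_token_id = 12, 4, 2, 7  # toy values
embed = nn.Embedding(vocab, hidden)

input_ids = torch.tensor([[3, 7, 7, 7, 7, 1]])  # 4 image tokens, divisible by patch_size
inputs_embeds = embed(input_ids)

# Locate placeholders from the embeddings alone, keeping only one component of the comparison.
image_mask = (inputs_embeds == embed(torch.tensor(image_token_id)))[..., 0]

num_image_tokens = image_mask.sum(dim=1)
if not torch.all(num_image_tokens % patch_size == 0):
    raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
print(num_image_tokens)  # tensor([4])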
@@ -717,14 +724,8 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        past_seen_tokens = 0
-        if use_cache:
-            if past_key_values is None:
-                past_key_values = DynamicCache()
-            past_seen_tokens = past_key_values.get_seq_length()
-
-        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
-            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
 
         if inputs_embeds is None:
             inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)

@@ -732,12 +733,13 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
         # START VISUAL INPUTS INTEGRATION
         if pixel_values is not None and image_hidden_states is not None:
             raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
-        elif pixel_values is not None:
-            image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device)
-        elif image_hidden_states is not None:
-            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+        if pixel_values is not None:
+            image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
+        elif image_hidden_states is not None:
+            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
 
-        if inputs_embeds is not None and image_hidden_states is not None:
+        if image_hidden_states is not None:
             # When we generate, we don't want to replace the potential image_token_id that we generated by images
             # that simply don't exist
             inputs_embeds = self.inputs_merger(

@@ -996,27 +998,11 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
             **kwargs,
         )
 
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        # but IDEFICS requires both ids and embeds to be present
-        if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs["input_ids"] = input_ids
-
-        if image_hidden_states is not None:
+        if image_hidden_states is not None or cache_position[0] != 0:
             model_inputs["pixel_values"] = None
             model_inputs["pixel_attention_mask"] = None
 
         return model_inputs
 
-    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
-        model_kwargs = super()._update_model_kwargs_for_generation(
-            outputs=outputs,
-            model_kwargs=model_kwargs,
-            is_encoder_decoder=is_encoder_decoder,
-            **kwargs,
-        )
-        # Get the precomputed image_hidden_states
-        model_kwargs["image_hidden_states"] = outputs.image_hidden_states
-        return model_kwargs
-
 
 __all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"]
@@ -180,7 +180,14 @@ class SmolVLMModel(Idefics3Model):
): ):
_, patch_size, _ = image_hidden_states.shape _, patch_size, _ = image_hidden_states.shape
image_mask = input_ids == self.image_token_id if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask[..., 0] # slice off the hidden dim
else:
image_mask = input_ids == self.config.image_token_id
num_image_tokens = image_mask.sum(dim=1) num_image_tokens = image_mask.sum(dim=1)
if not torch.all(num_image_tokens % patch_size == 0): if not torch.all(num_image_tokens % patch_size == 0):
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.") raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
@@ -296,14 +303,8 @@ class SmolVLMModel(Idefics3Model):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
past_seen_tokens = 0 if use_cache and past_key_values is None:
if use_cache: past_key_values = DynamicCache()
if past_key_values is None:
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device) inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
@@ -311,12 +312,13 @@ class SmolVLMModel(Idefics3Model):
# START VISUAL INPUTS INTEGRATION # START VISUAL INPUTS INTEGRATION
if pixel_values is not None and image_hidden_states is not None: if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
if inputs_embeds is not None and image_hidden_states is not None: if pixel_values is not None:
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images # When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist # that simply don't exist
inputs_embeds = self.inputs_merger( inputs_embeds = self.inputs_merger(
@@ -328,12 +328,6 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
-                "time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -343,10 +337,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
                 vision_feature_layer=vision_feature_layer,
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@@ -359,10 +361,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
                 pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
             )
-            special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.video_token_id
+
+            n_video_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
-                n_video_tokens = (input_ids == self.config.video_token_id).sum()
                 n_video_features = video_features.shape[0] * video_features.shape[1]
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
@@ -233,11 +233,6 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

@@ -246,10 +241,18 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
                 pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
             )
-            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = (special_image_mask).sum()
+            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                 n_image_features = image_features.shape[0] * image_features.shape[1]
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -136,11 +136,6 @@ class VipLlavaModel(LlavaModel):
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -149,10 +144,18 @@ class VipLlavaModel(LlavaModel):
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
) )
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) if input_ids is None:
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1] n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -118,27 +118,6 @@ from unittest.mock import patch
 from transformers.utils import is_sklearn_available
 
-# TODO: raushan remove this when VLMs start accepting input embeds
-VLM_CLASS_NAMES = [
-    "llava",
-    "idefics2",
-    "idefics3",
-    "mllama",
-    "paligemma",
-    "emu3",
-    "gotocr2",
-    "qwen2vl",
-    "qwen2_5_vl",
-    "ayavision",
-    "janus",
-    "gemma3",
-    "mistral3",
-    "chameleon",
-    "internvl",
-    "qwen2_5omni",  # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`
-]
-
 class GenerationTesterMixin:
     input_name = "input_ids"
     model_tester = None

@@ -1228,7 +1207,23 @@ class GenerationTesterMixin:
                 "blip2",  # overridden `generate()`
                 "instructblip",
                 "instructblipvideo",
-                *VLM_CLASS_NAMES,  # shouldn't suggest image tokens
+                # All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
+                "llava",
+                "idefics2",
+                "idefics3",
+                "mllama",
+                "paligemma",
+                "emu3",
+                "gotocr2",
+                "qwen2vl",
+                "qwen2_5_vl",
+                "ayavision",
+                "janus",
+                "gemma3",
+                "mistral3",
+                "chameleon",
+                "internvl",
+                "qwen2_5omni",  # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1641,6 +1636,58 @@ class GenerationTesterMixin:
                 self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
                 self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
 
+    @pytest.mark.generate
+    def test_generate_from_random_inputs_embeds(self):
+        """
+        Text-only: Tests that different `inputs_embeds` generate different outputs in models with `main_input=="input_ids"`.
+        Some models have 'images' as main input and thus can't generate with random text embeddings.
+        See `test_generate_from_inputs_embeds` for more general checks.
+        """
+        for model_class in self.all_generative_model_classes:
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+
+            if config.is_encoder_decoder:
+                continue
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
+                continue
+
+            # No easy fix, let's skip the test for now
+            has_complex_embeds_computation = any(
+                model_name in model_class.__name__.lower() for model_name in ["moshi"]
+            )
+            if model_class.main_input_name != "input_ids" or has_complex_embeds_computation:
+                self.skipTest(
+                    "The model's main input name in not `input_ids` and we need kwargs from input dict as well."
+                )
+
+            if hasattr(config, "scale_embedding"):
+                config.scale_embedding = False
+
+            generation_kwargs = {
+                "return_dict_in_generate": True,
+                "output_scores": True,
+                "do_sample": False,
+                "max_new_tokens": 5,
+                "min_new_tokens": 5,  # generate exactly 5 tokens
+            }
+
+            input_ids = inputs_dict.pop("input_ids")
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+            outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, **generation_kwargs)
+
+            # If we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # same, but the logits will almost surely be different)
+            random_embeds = torch.rand_like(inputs_embeds)
+            outputs_from_rand_embeds = model.generate(
+                input_ids=input_ids, inputs_embeds=random_embeds, **generation_kwargs
+            )
+            for i in range(len(outputs_from_rand_embeds.scores)):
+                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
+
     @pytest.mark.generate
     @parameterized.expand([("greedy", 1), ("beam search", 2)])
     def test_generate_from_inputs_embeds(self, _, num_beams):
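For context, the behaviour these generation tests exercise looks roughly like the following from the user side; the checkpoint id, prompt template, and dummy image are illustrative placeholders rather than anything taken from the test suite:

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Illustrative checkpoint; any chat-style VLM that merges image features via placeholder tokens should behave the same way.
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)

image = Image.new("RGB", (336, 336), color="white")  # dummy image standing in for a real one
inputs = processor(images=image, text="USER: <image>\nWhat is shown? ASSISTANT:", return_tensors="pt")

# Embed the prompt manually and pass the embeddings together with the pixels:
# after this change the model recovers the <image> placeholder positions from the
# embeddings themselves, so pixel_values no longer has to be dropped.
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
out = model.generate(
    inputs_embeds=inputs_embeds,
    pixel_values=inputs.pixel_values,
    attention_mask=inputs.attention_mask,
    max_new_tokens=20,
)
print(processor.batch_decode(out, skip_special_tokens=True)[0])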
@ -1662,34 +1709,22 @@ class GenerationTesterMixin:
continue continue
# There are a few exception patterns in this test: # There are a few exception patterns in this test:
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed # 1 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"])
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
# than calling the embedding layer with `input_ids`. Subcases of this exception: # than calling the embedding layer with `input_ids`. Subcases of this exception:
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag) # 1.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
if hasattr(config, "scale_embedding"): if hasattr(config, "scale_embedding"):
config.scale_embedding = False config.scale_embedding = False
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
# HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive; # HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive;
# this is similar to VLMs and should likely be standardized for similar audio models in the future, # this is similar to VLMs and should likely be standardized for similar audio models in the future,
# then made generic here. # then made generic here.
if "granitespeech" in model_class.__name__.lower(): if "granitespeech" in model_class.__name__.lower():
inputs_dict.pop("input_features", None) inputs_dict.pop("input_features", None)
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds` # 1.B - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
has_complex_embeds_computation = any( has_complex_embeds_computation = any(
model_name in model_class.__name__.lower() for model_name in ["moshi"] model_name in model_class.__name__.lower() for model_name in ["moshi"]
) )
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate, # 2 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input. # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
missing_attention_mask = "attention_mask" not in inputs_dict missing_attention_mask = "attention_mask" not in inputs_dict
@ -1702,31 +1737,23 @@ class GenerationTesterMixin:
"do_sample": False, "do_sample": False,
"max_new_tokens": 5, "max_new_tokens": 5,
"min_new_tokens": 5, # generate exactly 5 tokens "min_new_tokens": 5, # generate exactly 5 tokens
"use_cache": True,
} }
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict) outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5)) self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output). # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same. # The output of the two calls should be the same.
inputs_embeds = model.get_input_embeddings()(input_ids) inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate( outputs_from_embeds = model.generate(
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
) )
if not has_complex_embeds_computation: if not has_complex_embeds_computation:
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same # be the same
if not (requires_inputs_ids or missing_attention_mask): if not missing_attention_mask:
outputs_from_embeds_wo_ids = model.generate( outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
) )
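The reworked mixin test above now forwards `inputs_embeds` together with whatever multimodal tensors remain in `inputs_dict` (for example `pixel_values`), which is the behaviour this PR enables. A minimal user-facing sketch of the same pattern, assuming a LLaVA-style checkpoint; the model id, prompt format and image URL are illustrative assumptions, not taken from this diff:

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint; any VLM updated by this PR should behave the same
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

# embed the prompt ourselves, then pass the embeddings *and* the pixel values to generate()
input_ids = inputs.pop("input_ids")
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
    output_ids = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=20)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])

Previously many of these models treated `inputs_embeds` and `pixel_values` as mutually exclusive, which is why the per-model workarounds deleted further down in this diff existed in the first place.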
@ -1753,17 +1780,6 @@ class GenerationTesterMixin:
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
self.skipTest(reason="This model does not support `inputs_embeds` in generation") self.skipTest(reason="This model does not support `inputs_embeds` in generation")
# Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allows us to run the
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
input_ids = inputs_dict.pop("input_ids") input_ids = inputs_dict.pop("input_ids")
model.config.use_cache = True model.config.use_cache = True
@ -1925,14 +1941,6 @@ class GenerationTesterMixin:
if "past_key_values" not in outputs: if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`") self.skipTest(reason="This model doesn't return `past_key_values`")
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
input_ids = inputs_dict.pop("input_ids") input_ids = inputs_dict.pop("input_ids")
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1

View File

@ -189,49 +189,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
self.model_tester = AriaVisionText2TextModelTester(self) self.model_tester = AriaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip( @unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
) )
@ -270,14 +227,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
def test_dola_decoding_sample(self): def test_dola_decoding_sample(self):
pass pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_0_greedy(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
@unittest.skip(reason="Dynamic control flow due to MoE") @unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_with_static_cache(self): def test_generate_with_static_cache(self):
pass pass

View File

@ -62,7 +62,7 @@ class AyaVisionVisionText2TextModelTester:
bos_token_id=0, bos_token_id=0,
eos_token_id=0, eos_token_id=0,
pad_token_id=0, pad_token_id=0,
image_token_index=1, image_token_index=2,
num_channels=3, num_channels=3,
image_size=64, image_size=64,
model_type="aya_vision", model_type="aya_vision",
@ -183,49 +183,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip("Failing because of unique cache (HybridCache)") @unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs): def test_model_outputs_equivalence(self, **kwargs):
pass pass
@ -285,10 +242,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)") @unittest.skip("Failing because of unique cache (HybridCache)")
def test_multi_gpu_data_parallel_forward(self): def test_multi_gpu_data_parallel_forward(self):
pass pass

View File

@ -20,7 +20,6 @@ import unittest
import numpy as np import numpy as np
import pytest import pytest
import requests import requests
from parameterized import parameterized
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
@ -674,15 +673,6 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
# They should result in very similar logits # They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py # this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
class Blip2TextModelTester: class Blip2TextModelTester:

View File

@ -355,49 +355,6 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit
pixel_values = torch.cat([pixel_values, pixel_values], dim=0) pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values) _ = model(input_ids=input_ids, pixel_values=pixel_values)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@require_torch @require_torch
class ChameleonIntegrationTest(unittest.TestCase): class ChameleonIntegrationTest(unittest.TestCase):

View File

@ -189,50 +189,6 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester = ColPaliForRetrievalModelTester(self) self.model_tester = ColPaliForRetrievalModelTester(self)
self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@slow @slow
@require_vision @require_vision
def test_colpali_forward_inputs(self): def test_colpali_forward_inputs(self):

View File

@ -331,49 +331,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip( @unittest.skip(
"Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module" "Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module"
) )

View File

@ -131,10 +131,6 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.") @unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
def test_eager_matches_fa2_generate(self): def test_eager_matches_fa2_generate(self):
pass pass

View File

@ -13,12 +13,10 @@
# limitations under the License. # limitations under the License.
"""Testing suite for the PyTorch GLM-4.1V model.""" """Testing suite for the PyTorch GLM-4.1V model."""
import copy
import gc import gc
import unittest import unittest
import requests import requests
from parameterized import parameterized
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
@ -237,11 +235,6 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
def test_sdpa_can_dispatch_on_flash(self): def test_sdpa_can_dispatch_on_flash(self):
pass pass
@parameterized.expand([("greedy", 1), ("beam search", 2)])
@unittest.skip("Cannot generate from inputs embeds with pixel values")
def test_generate_from_inputs_embeds(self):
pass
@unittest.skip(reason="Size mismatch") @unittest.skip(reason="Size mismatch")
def test_multi_gpu_data_parallel_forward(self): def test_multi_gpu_data_parallel_forward(self):
pass pass
@ -250,34 +243,11 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
def test_model_is_small(self): def test_model_is_small(self):
pass pass
@unittest.skip("Cannot generate from inputs embeds with pixel values") @unittest.skip("Error with compilation")
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
# The multimodal base model embeds will not match ids, due to pixel values. We can't change base test # RoPE index doesn't match when using embeddings
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
# TODO: @raushan
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["image_grid_thw"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)[0]
def test_inputs_embeds_matches_input_ids(self): def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -51,9 +51,6 @@ class GotOcr2VisionText2TextModelTester:
num_channels=3, num_channels=3,
ignore_index=-100, ignore_index=-100,
image_size=64, image_size=64,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_index=1, image_token_index=1,
model_type="got_ocr2", model_type="got_ocr2",
is_training=True, is_training=True,
@ -71,6 +68,9 @@ class GotOcr2VisionText2TextModelTester:
"rope_theta": 10000, "rope_theta": 10000,
"mlp_ratio": 4, "mlp_ratio": 4,
"tie_word_embeddings": True, "tie_word_embeddings": True,
"bos_token_id": 2,
"eos_token_id": 3,
"pad_token_id": 4,
}, },
vision_config={ vision_config={
"num_hidden_layers": 2, "num_hidden_layers": 2,
@ -85,9 +85,9 @@ class GotOcr2VisionText2TextModelTester:
): ):
self.parent = parent self.parent = parent
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.bos_token_id = bos_token_id self.bos_token_id = text_config["bos_token_id"]
self.eos_token_id = eos_token_id self.eos_token_id = text_config["eos_token_id"]
self.pad_token_id = pad_token_id self.pad_token_id = text_config["pad_token_id"]
self.image_token_index = image_token_index self.image_token_index = image_token_index
self.model_type = model_type self.model_type = model_type
self.text_config = text_config self.text_config = text_config
@ -109,9 +109,6 @@ class GotOcr2VisionText2TextModelTester:
text_config=self.text_config, text_config=self.text_config,
vision_config=self.vision_config, vision_config=self.vision_config,
model_type=self.model_type, model_type=self.model_type,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
image_token_index=self.image_token_index, image_token_index=self.image_token_index,
) )
@ -127,7 +124,6 @@ class GotOcr2VisionText2TextModelTester:
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
# input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.image_token_index] = self.pad_token_id input_ids[input_ids == self.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = self.image_token_index input_ids[:, : self.num_image_tokens] = self.image_token_index
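The tester changes above give `bos`/`eos`/`pad` token ids that are distinct from `image_token_index`. With the old all-zero special ids, rewriting accidental image tokens to the pad token (and any pad masking done by the generic generation tests) could also clobber the real image placeholders. A toy, self-contained illustration of the preparation pattern, with arbitrary ids chosen for the sketch:

import torch

vocab_size, batch_size, seq_length, num_image_tokens = 32, 2, 10, 4
image_token_index, pad_token_id = 1, 4  # deliberately distinct ids

input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_length))
input_ids[input_ids == image_token_index] = pad_token_id  # drop image tokens that appeared by chance
input_ids[:, :num_image_tokens] = image_token_index       # place the real image placeholders
attention_mask = torch.ones_like(input_ids)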
@ -181,55 +177,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip( @unittest.skip(
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format" reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
) )

View File

@ -315,13 +315,6 @@ class IdeficsModelTester:
def prepare_pixel_values(self): def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip(reason="Idefics has a hard requirement on SDPA, skipping this test")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@require_torch @require_torch
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase): class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
@ -611,6 +604,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMi
def test_sdpa_can_dispatch_non_composite_models(self): def test_sdpa_can_dispatch_non_composite_models(self):
pass pass
@unittest.skip(reason="Idefics can't do text-only inference")
def test_generate_from_random_inputs_embeds(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@require_torch @require_torch
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase): class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
@ -899,6 +898,12 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
def test_generation_tester_mixin_inheritance(self): def test_generation_tester_mixin_inheritance(self):
pass pass
@unittest.skip(reason="Idefics can't do text-only inference")
def test_generate_from_random_inputs_embeds(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
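Both skips above refer to `test_generate_from_random_inputs_embeds`, the mixin test that now carries the random-embeddings comparison removed from the combined test earlier in this diff. A rough, self-contained sketch of that idea using a tiny randomly initialized GPT-2 as a stand-in (the mixin's actual implementation may differ):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_positions=32, n_embd=32, n_layer=2, n_head=2,
                    bos_token_id=0, eos_token_id=1, pad_token_id=2)
model = GPT2LMHeadModel(config).eval()

input_ids = torch.randint(3, 100, (1, 6))
inputs_embeds = model.get_input_embeddings()(input_ids)
generation_kwargs = {"max_new_tokens": 3, "min_new_tokens": 3, "do_sample": False,
                     "return_dict_in_generate": True, "output_scores": True}

outputs = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs)
outputs_rand = model.generate(
    input_ids=input_ids, inputs_embeds=torch.rand_like(inputs_embeds), **generation_kwargs
)

# different embeddings should produce different logits, even if the decoded text happens to match
assert not all(torch.allclose(a, b) for a, b in zip(outputs.scores, outputs_rand.scores))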
@require_torch @require_torch
@require_vision @require_vision

View File

@ -108,6 +108,7 @@ class Idefics2VisionText2TextModelTester:
image_token_id=99, image_token_id=99,
): ):
self.parent = parent self.parent = parent
self.pad_token_id = text_config["pad_token_id"]
self.is_training = is_training self.is_training = is_training
self.batch_size = batch_size self.batch_size = batch_size
self.num_images = num_images self.num_images = num_images
@ -158,6 +159,7 @@ class Idefics2VisionText2TextModelTester:
# For simplicity just set the last n tokens to the image token # For simplicity just set the last n tokens to the image token
n_image_tokens_per_batch = self.num_images * self.perceiver_config["resampler_n_latents"] n_image_tokens_per_batch = self.num_images * self.perceiver_config["resampler_n_latents"]
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id
attention_mask = input_ids.ne(1).to(torch_device) attention_mask = input_ids.ne(1).to(torch_device)
inputs_dict = { inputs_dict = {

View File

@ -96,6 +96,7 @@ class Idefics3VisionText2TextModelTester:
image_token_id=57, image_token_id=57,
): ):
self.parent = parent self.parent = parent
self.pad_token_id = text_config["pad_token_id"]
self.is_training = is_training self.is_training = is_training
self.batch_size = batch_size self.batch_size = batch_size
self.num_images = num_images self.num_images = num_images
@ -148,6 +149,7 @@ class Idefics3VisionText2TextModelTester:
# For simplicity just set the last n tokens to the image token # For simplicity just set the last n tokens to the image token
n_image_tokens_per_batch = self.seq_length n_image_tokens_per_batch = self.seq_length
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id
attention_mask = input_ids.ne(1).to(torch_device) attention_mask = input_ids.ne(1).to(torch_device)
inputs_dict = { inputs_dict = {

View File

@ -20,7 +20,6 @@ import unittest
import numpy as np import numpy as np
import pytest import pytest
import requests import requests
from parameterized import parameterized
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
@ -522,12 +521,6 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
def test_model_get_set_embeddings(self): def test_model_get_set_embeddings(self):
pass pass
@unittest.skip(
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -656,13 +649,6 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
# They should result in very similar logits # They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
@unittest.skip(
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
)
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@require_torch_sdpa @require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self): def test_sdpa_can_dispatch_composite_models(self):
""" """

View File

@ -20,7 +20,6 @@ import unittest
import numpy as np import numpy as np
import pytest import pytest
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from parameterized import parameterized
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
@ -535,12 +534,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
def test_model_common_attributes(self): def test_model_common_attributes(self):
pass pass
@unittest.skip(
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -669,13 +662,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
# They should result in very similar logits # They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
@unittest.skip(
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
)
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@require_torch_sdpa @require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self): def test_sdpa_can_dispatch_composite_models(self):
""" """

View File

@ -63,9 +63,6 @@ class InternVLVisionText2TextModelTester:
image_seq_length=64, image_seq_length=64,
vision_feature_layer=-1, vision_feature_layer=-1,
ignore_index=-100, ignore_index=-100,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_id=1, image_token_id=1,
num_channels=3, num_channels=3,
image_size=64, image_size=64,
@ -85,9 +82,9 @@ class InternVLVisionText2TextModelTester:
"rope_theta": 10000, "rope_theta": 10000,
"mlp_ratio": 4, "mlp_ratio": 4,
"tie_word_embeddings": True, "tie_word_embeddings": True,
"bos_token_id": 0, "bos_token_id": 3,
"eos_token_id": 0, "eos_token_id": 4,
"pad_token_id": 0, "pad_token_id": 5,
}, },
vision_config={ vision_config={
"hidden_size": 32, "hidden_size": 32,
@ -103,9 +100,9 @@ class InternVLVisionText2TextModelTester:
): ):
self.parent = parent self.parent = parent
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.bos_token_id = bos_token_id self.bos_token_id = text_config["bos_token_id"]
self.eos_token_id = eos_token_id self.eos_token_id = text_config["eos_token_id"]
self.pad_token_id = pad_token_id self.pad_token_id = text_config["pad_token_id"]
self.image_token_id = image_token_id self.image_token_id = image_token_id
self.model_type = model_type self.model_type = model_type
self.text_config = text_config self.text_config = text_config
@ -128,9 +125,6 @@ class InternVLVisionText2TextModelTester:
text_config=self.text_config, text_config=self.text_config,
vision_config=self.vision_config, vision_config=self.vision_config,
model_type=self.model_type, model_type=self.model_type,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
image_token_id=self.image_token_id, image_token_id=self.image_token_id,
image_seq_length=self.image_seq_length, image_seq_length=self.image_seq_length,
vision_feature_layer=self.vision_feature_layer, vision_feature_layer=self.vision_feature_layer,
@ -148,7 +142,6 @@ class InternVLVisionText2TextModelTester:
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
# input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id input_ids[:, : self.image_seq_length] = self.image_token_id
@ -222,49 +215,6 @@ class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip(reason="Compile not yet supported because in LLava models") @unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_compile_dynamic(self):
pass pass

View File

@ -153,6 +153,7 @@ class JanusVisionText2TextModelTester:
text_config=self.text_config, text_config=self.text_config,
vision_config=self.vision_config, vision_config=self.vision_config,
vq_config=self.get_vq_config(), vq_config=self.get_vq_config(),
image_token_id=self.image_token_index,
) )
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
@ -200,50 +201,6 @@ class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, uni
self.model_tester = JanusVisionText2TextModelTester(self) self.model_tester = JanusVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=JanusConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=JanusConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["generation_mode"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# Overwrite inputs_embeds tests because we need to delete "pixel values" for VLMs.
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["generation_mode"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_sdpa_can_dispatch_composite_models(self): def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -457,14 +457,6 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape) # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
@pytest.mark.generate
@parameterized.expand([("greedy", 1), ("beam search", 2)])
@unittest.skip(
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
)
def test_generate_from_inputs_embeds(self):
pass
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa @require_torch_sdpa
@unittest.skip("KOSMOS-2 doesn't support padding") @unittest.skip("KOSMOS-2 doesn't support padding")
@ -613,6 +605,53 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# (Even with this call, there are still memory leak by ~0.04MB) # (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry() self.clear_torch_jit_class_registry()
@pytest.mark.generate
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
# NOTE: overwritten because Kosmos with ids prepares position ids differently from embeds
# If the model gets ids, all pad tokens are masked from position ids. That is not possible with embeds
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if it fails, you should probably update the `prepare_inputs_for_generation` function
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
config.is_decoder = True
# Skip models without explicit support
model = model_class(config).to(torch_device).eval()
# Traditional way of generating text
input_ids = inputs_dict.pop("input_ids")
input_ids[input_ids == config.get_text_config().pad_token_id] = 0
generation_kwargs = {
"return_dict_in_generate": True,
"output_scores": True,
"num_beams": num_beams,
"do_sample": False,
"max_new_tokens": 5,
"min_new_tokens": 5, # generate exactly 5 tokens
"use_cache": True,
}
outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
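The overwrite above is needed because, when Kosmos-2 receives `input_ids`, it builds position ids that skip pad tokens, and that cannot be reconstructed from `inputs_embeds` alone, hence the test first rewrites pads to a regular token. A self-contained sketch of the pad-skipping position-id pattern (illustrative only, not Kosmos-2's exact code):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# positions advance only over real tokens; padded slots get a dummy position
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = position_ids.masked_fill(attention_mask == 0, 0)
print(position_ids)
# tensor([[0, 1, 2, 0, 0],
#         [0, 1, 2, 3, 4]])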
# We will verify our results on an image of cute cats # We will verify our results on an image of cute cats
def prepare_img(): def prepare_img():

View File

@ -196,49 +196,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_mismatching_num_image_tokens(self): def test_mismatching_num_image_tokens(self):
""" """
Tests that VLMs throw an error with explicit message saying what is wrong Tests that VLMs throw an error with explicit message saying what is wrong

View File

@ -222,49 +222,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_mismatching_num_image_tokens(self): def test_mismatching_num_image_tokens(self):
""" """
Tests that VLMs throw an error with explicit message saying what is wrong Tests that VLMs throw an error with explicit message saying what is wrong

View File

@ -86,7 +86,7 @@ class LlavaNextVideoVisionText2TextModelTester:
"initializer_range": 0.02, "initializer_range": 0.02,
"num_labels": 3, "num_labels": 3,
"num_choices": 4, "num_choices": 4,
"pad_token_id": 2, "pad_token_id": 3,
}, },
is_training=True, is_training=True,
vision_config={ vision_config={
@ -234,51 +234,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["pixel_values_videos"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["pixel_values_videos"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_mismatching_num_image_tokens(self): def test_mismatching_num_image_tokens(self):
""" """
Tests that VLMs throw an error with explicit message saying what is wrong Tests that VLMs throw an error with explicit message saying what is wrong

View File

@ -230,49 +230,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_odd_sized_image(self): def test_odd_sized_image(self):
# prepare model configuration # prepare model configuration
config = self.model_tester.get_config() config = self.model_tester.get_config()

View File

@ -57,9 +57,6 @@ class Mistral3VisionText2TextModelTester:
image_seq_length=4, image_seq_length=4,
vision_feature_layer=-1, vision_feature_layer=-1,
ignore_index=-100, ignore_index=-100,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_index=1, image_token_index=1,
num_channels=3, num_channels=3,
image_size=30, image_size=30,
@ -80,9 +77,9 @@ class Mistral3VisionText2TextModelTester:
"rms_norm_eps": 1e-05, "rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0, "rope_theta": 1000000000.0,
"sliding_window": None, "sliding_window": None,
"bos_token_id": 0, "bos_token_id": 2,
"eos_token_id": 0, "eos_token_id": 3,
"pad_token_id": 0, "pad_token_id": 4,
}, },
vision_config={ vision_config={
"model_type": "pixtral", "model_type": "pixtral",
@ -98,9 +95,9 @@ class Mistral3VisionText2TextModelTester:
): ):
self.parent = parent self.parent = parent
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.bos_token_id = bos_token_id self.bos_token_id = text_config["bos_token_id"]
self.eos_token_id = eos_token_id self.eos_token_id = text_config["eos_token_id"]
self.pad_token_id = pad_token_id self.pad_token_id = text_config["pad_token_id"]
self.image_token_index = image_token_index self.image_token_index = image_token_index
self.model_type = model_type self.model_type = model_type
self.text_config = text_config self.text_config = text_config
@ -209,49 +206,6 @@ class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
msg=f"Parameter {name} of model {model_class} seems not properly initialized", msg=f"Parameter {name} of model {model_class} seems not properly initialized",
) )
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@unittest.skip(reason="Compile not yet supported because in LLava models") @unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_compile_dynamic(self):
pass pass

View File

@ -199,49 +199,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
self.model_tester = PaliGemmaVisionText2TextModelTester(self) self.model_tester = PaliGemmaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
def test_mismatching_num_image_tokens(self): def test_mismatching_num_image_tokens(self):
""" """
@ -327,12 +284,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass
@unittest.skip(
reason="VLMs doesn't accept inputs embeds and pixel values at the same time. So if the test passed for backbone LM, it passes for VLM also"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@ -183,49 +183,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
self.model_tester = PaliGemma2VisionText2TextModelTester(self) self.model_tester = PaliGemma2VisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False) self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
def test_mismatching_num_image_tokens(self):
"""
@ -311,12 +268,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_feed_forward_chunking(self):
pass
@unittest.skip(
reason="VLMs doesn't accept inputs embeds and pixel values at the same time. So if the test passed for backbone LM, it passes for VLM also"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
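The skips deleted from the two PaliGemma test files above existed because a VLM's `generate` previously could not take `inputs_embeds` together with `pixel_values`. A minimal usage sketch of the new behaviour, assuming a PaliGemma-style checkpoint and processor (the checkpoint id, prompt, and dummy image are illustrative placeholders, not taken from this diff):

# Sketch: pass precomputed text embeddings and pixel values to `generate` together.
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "google/paligemma-3b-pt-224"  # placeholder checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16)

image = Image.new("RGB", (224, 224))  # dummy image, only to keep the sketch self-contained
inputs = processor(images=image, text="caption en", return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

# `pixel_values` no longer has to be dropped when `inputs_embeds` is used; a static
# cache can be requested as well, matching the un-skipped static-cache test above.
output_ids = model.generate(
    inputs_embeds=inputs_embeds,
    pixel_values=inputs["pixel_values"].to(model.dtype),
    attention_mask=inputs["attention_mask"],
    cache_implementation="static",
    max_new_tokens=20,
)
print(processor.batch_decode(output_ids, skip_special_tokens=True))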


@ -22,7 +22,6 @@ from urllib.request import urlopen
import librosa
import requests
from parameterized import parameterized
from transformers import (
AutoProcessor,
@ -289,10 +288,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="QwenOmniThinker does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="QwenOmniThinker does not support output_hidden_states test") @unittest.skip(reason="QwenOmniThinker does not support output_hidden_states test")
def test_model_outputs_equivalence(self): def test_model_outputs_equivalence(self):
pass pass
@ -337,11 +332,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers") raise ValueError("The eager model should not have SDPA attention layers")
@parameterized.expand([("greedy", 1), ("beam search", 2)])
@unittest.skip("Cannot generate from inputs embeds")
def test_generate_from_inputs_embeds(self):
pass
@unittest.skip("Cannot do contrastive generation, has custom `generate()`") @unittest.skip("Cannot do contrastive generation, has custom `generate()`")
def test_contrastive_generate(self): def test_contrastive_generate(self):
pass pass


@ -357,39 +357,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_model_is_small(self):
pass
@unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
def test_prompt_lookup_decoding_matches_greedy_search(self):
super().test_prompt_lookup_decoding_matches_greedy_search()
# The multimodal base model embeds will not match ids, due to pixel values. We can't change base test
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
# TODO: @raushan
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@require_torch
class Qwen2_5_VLIntegrationTest(unittest.TestCase):


@ -313,35 +313,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_model_is_small(self):
pass
@unittest.skip(
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
# The multimodal base model embeds will not match ids, due to pixel values. We can't change base test
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
# TODO: @raushan
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
@require_torch
class Qwen2VLIntegrationTest(unittest.TestCase):
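The identical overrides deleted from the two Qwen VL test files above compared outputs from `input_ids` and `inputs_embeds`, but had to delete `pixel_values` first. A simplified sketch of that equivalence check with the pixel inputs kept, which is what the shared test can presumably do now (the helper name and structure are illustrative, not the actual common test):

# Simplified sketch of the ids-vs-embeds equivalence; `inputs` is expected to hold
# input_ids plus whatever pixel inputs the model needs.
import torch

def assert_embeds_match_ids(model, inputs):
    inputs = dict(inputs)  # don't mutate the caller's dict
    input_ids = inputs.pop("input_ids")
    inputs_embeds = model.get_input_embeddings()(input_ids)
    with torch.no_grad():
        out_ids = model(input_ids=input_ids, **inputs)[0]
        out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
    # Pixel inputs stay in `inputs` for both calls instead of being deleted.
    torch.testing.assert_close(out_embeds, out_ids)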


@ -181,14 +181,6 @@ class SmolVLMModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
def test_inputs_embeds():
pass
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Model does not support padding right") @unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
pass pass
@ -347,10 +339,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
self.model_tester = SmolVLMVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=SmolVLMConfig, has_text_modality=False)
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
def test_inputs_embeds():
pass
@unittest.skip(reason="Model does not support padding right") @unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
pass pass
@ -394,14 +382,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_0_greedy(self):
pass
@unittest.skip(reason="Unsupported")
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
@unittest.skip(reason="Unsupported") @unittest.skip(reason="Unsupported")
def test_generate_with_static_cache(self): def test_generate_with_static_cache(self):
pass pass


@ -344,51 +344,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
continue
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values_images"]
del inputs["pixel_values_videos"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values_images"]
del inputs["pixel_values_videos"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
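The VideoLlava override deleted just above had to strip both `pixel_values_images` and `pixel_values_videos` before swapping `input_ids` for `inputs_embeds`. A small sketch of a forward pass that keeps every pixel key (the key names follow the removed test; the helper itself is illustrative, not code from this commit):

import torch

def forward_with_embeds(model, inputs):
    # Replace input_ids with embeddings while leaving pixel_values_images /
    # pixel_values_videos (or any other pixel inputs) untouched.
    inputs = dict(inputs)
    input_ids = inputs.pop("input_ids")
    inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
    with torch.no_grad():
        return model(**inputs)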
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs through an error with explicit message saying what is wrong


@ -192,49 +192,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_config(self):
self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
def test_mismatching_num_image_tokens(self):
"""


@ -2829,7 +2829,9 @@ class ModelTesterMixin:
self.skipTest(reason="This model doesn't use `inputs_embeds`")
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 1
)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder: