mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
[VLMs] support passing embeds along with pixels (#38467)
* VLMs can work with embeds now
* update more models
* fix tests
* fix copies
* fixup
* fix
* style
* unskip tests
* fix copies
* fix tests
* style
* omni modality models
* qwen models had extra indentation
* fix some other tests
* fix copies
* fix test last time
* unrelated changes revert
* we can't rely only on embeds
* delete file
* de-flake mistral3
* fix qwen models
* fix style
* fix tests
* fix copies
* deflake the test
* modular reverted by fixes, fix again
* flaky test, overwritten
* fix copies
* style
This commit is contained in:
parent 20901f1d68
commit f8b88866f5
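The change below lets vision-language models accept precomputed text `inputs_embeds` together with `pixel_values` in `forward()` and `generate()`. A minimal usage sketch follows, assuming a chat-capable VLM covered by this commit; the checkpoint name, prompt format, and image URL are illustrative placeholders, not part of the diff.

```python
# Sketch only: pass text embeddings instead of input_ids while still supplying pixel_values.
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "llava-hf/llava-1.5-7b-hf"  # assumption: any VLM updated by this PR works the same way
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

# Embed the prompt ourselves and drop input_ids; pixel_values and attention_mask stay as-is.
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
generated = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=32)
print(processor.batch_decode(generated, skip_special_tokens=True))
```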
@@ -733,7 +733,9 @@ class GenerationMixin(ContinuousMixin):
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
if model_kwargs["inputs_embeds"] is None:
model_kwargs.pop("inputs_embeds")
elif not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
@@ -748,10 +750,11 @@ class GenerationMixin(ContinuousMixin):
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
@@ -1113,11 +1113,12 @@ class AriaModel(AriaPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
special_image_mask = special_image_mask.all(-1)
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
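The Aria hunk above shows the pattern that the remaining model hunks in this diff converge on: when `input_ids` is unavailable, the image-placeholder mask is derived from `inputs_embeds` itself by comparing each position against the image token's embedding. Below is a standalone sketch of that pattern; the names (`embed_layer`, `image_token_id`) are illustrative, not a specific model's attributes.

```python
import torch

def build_special_image_mask(input_ids, inputs_embeds, embed_layer, image_token_id):
    # A minimal sketch of the mask construction used throughout this diff.
    if input_ids is None:
        # No ids: mark positions whose embedding equals the image-token embedding.
        image_token_embed = embed_layer(
            torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
        )
        special_image_mask = (inputs_embeds == image_token_embed).all(-1)
    else:
        special_image_mask = input_ids == image_token_id
    # Broadcast to hidden size so it can be used with masked_scatter on inputs_embeds.
    return special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

# Typical use, mirroring the hunks below:
# inputs_embeds = inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))
```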
@@ -1446,11 +1446,12 @@ class AriaModel(LlavaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
special_image_mask = special_image_mask.all(-1)
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
@@ -302,14 +302,14 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -223,14 +223,14 @@ class AyaVisionModel(LlavaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
@@ -1855,6 +1855,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)

_keep_in_fp32_modules = ["query_tokens", "qformer"]

def __init__(self, config: Blip2Config):
@@ -1971,10 +1972,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.FloatTensor,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
@@ -2066,14 +2068,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concating
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
@@ -2146,6 +2159,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
pixel_values: torch.FloatTensor,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
@@ -2159,6 +2173,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings.

Returns:
captions (list): A list of strings of length batch_size * num_captions.
@@ -2193,22 +2211,32 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
if inputs_embeds is None:
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)

inputs_embeds = self.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
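In the BLIP-2 `generate()` hunk above, the old code always built a default prompt and embedded it; the new code only does so when `inputs_embeds` was not supplied. A condensed sketch of that fallback is below, using the attribute names from the diff; it is an illustration of the control flow, not the full method.

```python
import torch

def default_prompt_ids(config, batch_size, device):
    # Sketch of the default prompt BLIP-2 builds when neither input_ids nor inputs_embeds
    # is given: one image placeholder per query token, followed by BOS (as in the hunk above).
    start_tokens = [config.text_config.bos_token_id]
    if getattr(config, "image_token_id", None) is not None:
        start_tokens = [config.image_token_id] * config.num_query_tokens + start_tokens
    input_ids = torch.tensor([start_tokens], dtype=torch.long, device=device)
    return input_ids.repeat(batch_size, 1)

# generate() then embeds these ids only if the caller did not already pass inputs_embeds.
```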
@@ -963,25 +963,28 @@ class ChameleonModel(ChameleonPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

n_image_tokens_in_text = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

image_embeds = self.get_image_features(pixel_values)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel():
n_image_features = image_embeds.shape[0] * image_embeds.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
@@ -1537,20 +1537,26 @@ class Emu3Model(Emu3PreTrainedModel):
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
image_embeds = self.get_image_features(pixel_values, image_sizes)
image_embeds = torch.cat(image_embeds, dim=0)

if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -1033,20 +1033,26 @@ class Emu3Model(Emu3PreTrainedModel):
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
image_embeds = self.get_image_features(pixel_values, image_sizes)
image_embeds = torch.cat(image_embeds, dim=0)

if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -206,14 +206,22 @@ class FuyuModel(FuyuPreTrainedModel):

if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = self.get_image_features(image_patches)
patch_embeddings = torch.cat(patch_embeddings, dim=0)

special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
if image_patches is not None:
patch_embeddings = self.get_image_features(image_patches)
patch_embeddings = torch.cat(patch_embeddings, dim=0)

if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)

outputs = self.language_model(
inputs_embeds=inputs_embeds,
@@ -898,9 +898,11 @@ class Gemma3Model(Gemma3PreTrainedModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
@@ -800,9 +800,11 @@ class Gemma3Model(PaliGemmaModel):
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
@@ -1237,50 +1237,59 @@ class Glm4vModel(Glm4vPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()

if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id

n_image_tokens = image_mask.sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)

mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.image_token_id).sum()

if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id

n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)

mask = input_ids == self.config.image_token_id # GLM-4.1V use image_token_id for video
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

if position_ids is None:
attention_mask_tensor = attention_mask
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
@@ -1571,6 +1580,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1585,9 +1595,29 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
"""

is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id
if inputs_embeds is not None:
is_image = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_start = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_end = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id

# Cumulative sum to track if we're inside a video span
# We'll assume well-formed video tags (i.e. matching starts and ends)
@@ -1623,7 +1653,9 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)

def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
@@ -1679,10 +1711,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand

# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
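The GLM-4.1V hunks above extend `_get_image_nums_and_video_nums` so image and video counts can be recovered even when only `inputs_embeds` is available, by comparing each position against the embedding of the relevant start/end token. A standalone illustration of that detection trick follows; names are illustrative, and note that the diff compares only the first embedding channel (`[..., 0]`) whereas this sketch checks all channels.

```python
import torch

def count_token_in_embeds(inputs_embeds, embed_layer, token_id):
    # Find positions whose embedding row equals the given special token's embedding.
    # Exact equality only makes sense when the embeds were produced by the same embedding
    # table, which is the case when generate() embedded the prompt itself.
    token_embed = embed_layer(
        torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device)
    )
    is_token = (inputs_embeds == token_embed).all(-1)  # (batch, seq_len) boolean mask
    return is_token.sum(dim=-1)  # per-sample occurrence count

# e.g. image_nums = count_token_in_embeds(embeds, model.get_input_embeddings(),
#                                          model.config.image_start_token_id)
```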
@@ -1237,50 +1237,59 @@ class Glm4vModel(Qwen2_5_VLModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()

if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id

n_image_tokens = image_mask.sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)

mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.image_token_id).sum()

if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id

n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)

mask = input_ids == self.config.image_token_id # GLM-4.1V use image_token_id for video
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

if position_ids is None:
attention_mask_tensor = attention_mask
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
@@ -1500,6 +1509,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@@ -1514,9 +1524,29 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
"""

is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id
if inputs_embeds is not None:
is_image = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_start = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
is_video_end = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id

# Cumulative sum to track if we're inside a video span
# We'll assume well-formed video tags (i.e. matching starts and ends)
@@ -648,24 +648,27 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

@@ -339,24 +339,27 @@ class GotOcr2Model(LlavaModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -933,10 +933,18 @@ class Idefics2Model(Idefics2PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
special_image_token_mask = input_ids == self.image_token_id
new_inputs_embeds = inputs_embeds.clone()
new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device)
return new_inputs_embeds
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
return inputs_embeds

def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
"""
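The `inputs_merger` hunk above replaces in-place boolean-index assignment into a cloned tensor with an out-of-place `masked_scatter`, which works with a mask derived either from `input_ids` or from `inputs_embeds`. The toy comparison below (illustrative tensors only, not Idefics2 code) shows that the two styles produce the same merged embeddings.

```python
import torch

batch, seq, hidden, n_img = 1, 6, 4, 2
inputs_embeds = torch.randn(batch, seq, hidden)
image_hidden_states = torch.randn(n_img, hidden)
token_mask = torch.tensor([[False, True, True, False, False, False]])  # image-token positions

# Old style: in-place assignment into a clone (requires input_ids to build the mask).
merged_old = inputs_embeds.clone()
merged_old[token_mask] = image_hidden_states

# New style: out-of-place masked_scatter over a mask broadcast to hidden size.
expanded_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds)
merged_new = inputs_embeds.masked_scatter(expanded_mask, image_hidden_states)

assert torch.allclose(merged_old, merged_new)
```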
@@ -1041,25 +1049,8 @@ class Idefics2Model(Idefics2PreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_seen_tokens = 0
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache:
if not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)
past_seen_tokens = past_key_values.get_seq_length()

if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
if use_cache and not isinstance(past_key_values, Cache):
past_key_values = DynamicCache()

if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
@@ -1072,7 +1063,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self.inputs_merger(
@@ -1094,9 +1085,6 @@ class Idefics2Model(Idefics2PreTrainedModel):
**kwargs,
)

if return_legacy_cache and use_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

return Idefics2BaseModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
@@ -1304,37 +1292,11 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
**kwargs,
)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = input_ids

if image_hidden_states is not None:
if image_hidden_states is not None or cache_position[0] != 0:
model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None

return model_inputs

def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs

@staticmethod
# Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past


__all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]
@@ -663,15 +663,18 @@ class Idefics3Model(Idefics3PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
special_image_token_mask = input_ids == self.image_token_id
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
new_inputs_embeds = inputs_embeds.clone()
# Flatten `image_hidden_states` if not flat yet
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
# cast to the dtype of the input_embeds to support quantized models
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = image_hidden_states
return new_inputs_embeds
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
return inputs_embeds

def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
"""
@@ -773,11 +776,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_seen_tokens = 0
if use_cache:
if past_key_values is None:
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if inputs_embeds is None:
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
@@ -790,7 +790,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None:
if image_hidden_states is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self.inputs_merger(
@@ -1042,28 +1042,11 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
**kwargs,
)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = input_ids

if image_hidden_states is not None:
if image_hidden_states is not None or cache_position[0] != 0:
model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None

return model_inputs

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs


__all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]
@@ -1255,6 +1255,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@@ -1328,12 +1329,20 @@ class InstructBlipModel(InstructBlipPreTrainedModel):

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
special_image_mask = input_ids == self.config.image_token_id
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
else:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)

special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
@@ -1513,6 +1522,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
@@ -1604,15 +1614,26 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
@@ -1673,6 +1694,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
@@ -1690,6 +1712,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings.

@@ -1712,23 +1736,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)
if inputs_embeds is None:
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

inputs_embeds = self.get_input_embeddings()(input_ids)

# if the model already has "image_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "image_token_id", None) is not None:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
@ -1251,6 +1251,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
@ -1334,12 +1335,20 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):

# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
special_image_mask = input_ids == self.config.video_token_id
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
else:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)

special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

if self.config.use_decoder_only_language_model:
outputs = self.language_model(
@ -1485,6 +1494,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
@ -1599,15 +1609,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

# if the model already has "video_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "video_token_id", None) is not None:
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
@ -1668,6 +1689,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False,
**generate_kwargs,
) -> torch.LongTensor:
@ -1685,6 +1707,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
The sequence used as a prompt for the generation.
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
Mask to avoid performing attention on padding token indices.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the positional encoding of the image embeddings.

@ -1708,23 +1732,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)
if inputs_embeds is None:
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
input_ids = input_ids.repeat(batch_size, 1)
inputs_embeds = self.get_input_embeddings()(input_ids)

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

inputs_embeds = self.get_input_embeddings()(input_ids)

# if the model already has "video_token_id" then the input is expanded to account for image embeds
# otherwise we expand manually by concatenating
if getattr(self.config, "video_token_id", None) is not None:
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.video_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
else:
logger.warning_once(
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "

@ -202,6 +202,7 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
@ -255,12 +256,20 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
|
||||
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
|
||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -372,6 +381,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@ -451,15 +461,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
@ -520,6 +541,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
@ -537,6 +559,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
The sequence used as a prompt for the generation.
|
||||
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Embedded representation of the inputs. Should be float, not int tokens.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the positional encoding of the image embeddings.
|
||||
|
||||
@ -560,23 +584,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
if inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
|
@ -710,14 +710,14 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -641,14 +641,14 @@ class InternVLModel(LlavaModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -1102,23 +1102,21 @@ class JanusModel(JanusPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
if input_ids is None:
|
||||
image_attention_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_attention_mask = image_attention_mask.all(-1)
|
||||
else:
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
embed_dim = inputs_embeds.shape[-1]
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
|
||||
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
|
||||
image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
|
||||
|
||||
|
@ -955,23 +955,21 @@ class JanusModel(JanusPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
if input_ids is None:
|
||||
image_attention_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_attention_mask = image_attention_mask.all(-1)
|
||||
else:
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
embed_dim = inputs_embeds.shape[-1]
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
|
||||
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
|
||||
image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
|
||||
|
||||
|
@ -1467,25 +1467,19 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
image_embeds_position_mask=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
cache_position=None,
|
||||
**model_kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
# Kosmos2 has offset for position ids, so we need to create them correctly
|
||||
position_ids = create_position_ids_from_input_ids(
|
||||
input_ids,
|
||||
padding_idx=self.config.pad_token_id,
|
||||
past_key_values_length=0,
|
||||
)
|
||||
|
||||
if past_key_values is not None:
|
||||
image_embeds = None
|
||||
image_embeds_position_mask = None
|
||||
# appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
|
||||
elif image_embeds_position_mask is not None:
|
||||
batch_size, seq_len = input_ids.size()
|
||||
batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()
|
||||
mask_len = image_embeds_position_mask.size()[-1]
|
||||
image_embeds_position_mask = torch.cat(
|
||||
(
|
||||
@ -1501,11 +1495,13 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
attention_mask=attention_mask,
|
||||
image_embeds=image_embeds,
|
||||
image_embeds_position_mask=image_embeds_position_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
**model_kwargs,
|
||||
)
|
||||
# Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
|
||||
model_inputs.pop("position_ids", None)
|
||||
|
||||
return model_inputs
|
||||
|
||||
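A condensed sketch of the shape handling introduced in `prepare_inputs_for_generation` above: when generation is driven by `inputs_embeds`, the batch and sequence sizes are read from the embedding tensor instead of `input_ids`. The tensor shapes below are made up for illustration only.

import torch

def batch_and_seq(input_ids, inputs_embeds):
    # mirrors `inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()`
    return inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()

print(batch_and_seq(torch.zeros(2, 5, dtype=torch.long), None))  # torch.Size([2, 5])
print(batch_and_seq(None, torch.zeros(2, 5, 16)))                # torch.Size([2, 5])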
@ -1875,6 +1871,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# in order to allow `inputs` argument (as in `GenerationMixin`)
|
||||
@ -1900,6 +1897,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
attention_mask=attention_mask,
|
||||
image_embeds=image_embeds,
|
||||
image_embeds_position_mask=image_embeds_position_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -1358,27 +1358,28 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
original_inputs_embeds_shape = inputs_embeds.shape
|
||||
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||
num_tokens_to_fill = final_mask_1d.sum()
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if num_tokens_to_fill != projected_vision_flat.size(0):
|
||||
if n_image_tokens != projected_vision_flat.size(0):
|
||||
raise ValueError(
|
||||
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
|
||||
f"Mismatch: final_mask wants {n_image_tokens} embeddings, "
|
||||
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
|
||||
)
|
||||
|
||||
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
|
||||
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
|
||||
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
|
||||
projected_vision_flat = projected_vision_flat.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
|
@ -284,14 +284,14 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -468,11 +468,6 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -485,10 +480,18 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -519,12 +519,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -537,10 +531,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -559,10 +561,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -440,12 +440,6 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -458,10 +452,18 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -480,10 +482,18 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -551,12 +551,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -571,10 +565,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -595,10 +597,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_video_mask = special_video_mask.all(-1)
|
||||
else:
|
||||
special_video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_video_mask).sum()
|
||||
special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -535,12 +535,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -555,10 +549,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -579,10 +581,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_video_mask = special_video_mask.all(-1)
|
||||
else:
|
||||
special_video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_video_mask).sum()
|
||||
special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -308,11 +308,6 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -324,10 +319,18 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -204,11 +204,6 @@ class Mistral3Model(LlavaModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -220,10 +215,18 @@ class Mistral3Model(LlavaModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -331,9 +331,11 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
|
@ -1903,43 +1903,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text , audios , image and video
|
||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||
if input_features is not None:
|
||||
audio_features = self.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
if input_features is not None:
|
||||
audio_features = self.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
)
|
||||
if input_ids is None:
|
||||
audio_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
audio_mask = audio_mask.all(-1)
|
||||
else:
|
||||
audio_mask = input_ids == self.config.audio_token_id
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
image_mask = image_mask.all(-1)
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
if input_ids is None:
|
||||
video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_mask = video_mask.all(-1)
|
||||
else:
|
||||
video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
|
@ -2350,43 +2350,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text , audios , image and video
|
||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||
if input_features is not None:
|
||||
audio_features = self.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
if input_features is not None:
|
||||
audio_features = self.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
)
|
||||
if input_ids is None:
|
||||
audio_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
audio_mask = audio_mask.all(-1)
|
||||
else:
|
||||
audio_mask = input_ids == self.config.audio_token_id
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
image_mask = image_mask.all(-1)
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
if input_ids is None:
|
||||
video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_mask = video_mask.all(-1)
|
||||
else:
|
||||
video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
|
@ -1245,41 +1245,51 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)

mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0)

image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id

if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_embeds.shape[0]
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
n_image_tokens = (image_mask).sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0)

video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id

n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)

video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

if position_ids is None:
attention_mask_tensor = (
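The hunk above applies the same recipe to images and videos: build a boolean placeholder mask (from `input_ids` when they exist, otherwise by comparing `inputs_embeds` against the placeholder embedding), check that the number of placeholder tokens matches the number of vision features, and scatter the features into the sequence. A condensed sketch of that recipe, with helper and argument names of my own choosing rather than anything defined in this file:

import torch

def merge_vision_features(inputs_embeds, features, placeholder_mask):
    # placeholder_mask: (batch, seq_len) bool; features: (num_placeholder_tokens, hidden)
    n_tokens = int(placeholder_mask.sum())
    if n_tokens != features.shape[0]:
        raise ValueError(f"placeholder tokens ({n_tokens}) do not match features ({features.shape[0]})")
    expanded_mask = placeholder_mask.unsqueeze(-1).expand_as(inputs_embeds)
    features = features.to(inputs_embeds.device, inputs_embeds.dtype)
    # masked_scatter fills the masked positions, row by row, with the feature vectors
    return inputs_embeds.masked_scatter(expanded_mask, features)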
@ -1586,6 +1596,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
@ -1603,10 +1614,31 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id

vision_start_mask = input_ids == vision_start_token_id
if inputs_embeds is not None:
vision_start_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
image_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
video_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
vision_start_mask = input_ids == vision_start_token_id
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id

vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

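When only `inputs_embeds` are available, the helper above recovers the special-token positions by comparing each embedding row with the embedding of the token id (this helper keeps just `[..., 0]`, while the forward pass uses `.all(-1)`). A minimal sketch of that trick, assuming the embeddings were produced by the same embedding layer; the tensor names below are illustrative only:

import torch

embed_layer = torch.nn.Embedding(10, 4)
special_token_id = 7
ids = torch.tensor([[1, 7, 7, 3]])
inputs_embeds = embed_layer(ids)

# A position is a placeholder only if every hidden dimension matches the special row.
special_row = embed_layer(torch.tensor(special_token_id))
mask = (inputs_embeds == special_row).all(-1)
print(mask)  # tensor([[False,  True,  True, False]])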
@ -1632,7 +1664,9 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)

def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
@ -1688,10 +1722,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand

# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
@ -609,41 +609,51 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.image_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_mask = image_mask.all(-1)
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
n_image_tokens = (image_mask).sum()
|
||||
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
mask = input_ids == self.config.video_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
if input_ids is None:
|
||||
video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_mask = video_mask.all(-1)
|
||||
else:
|
||||
video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (video_mask).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_tensor = (
|
||||
|
@ -1182,41 +1182,52 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
image_mask = image_mask.all(-1)
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = image_mask.sum()
|
||||
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
|
||||
if input_ids is None:
|
||||
video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_video_tokens = (video_mask).sum(dim=1).sum(dim=0)[0]
|
||||
else:
|
||||
video_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
video_mask = video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_video_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_tensor = (
|
||||
@ -1480,6 +1491,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
def _get_image_nums_and_video_nums(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
|
||||
@ -1497,10 +1509,31 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
video_token_id = self.config.video_token_id
|
||||
vision_start_token_id = self.config.vision_start_token_id
|
||||
|
||||
vision_start_mask = input_ids == vision_start_token_id
|
||||
if inputs_embeds is not None:
|
||||
vision_start_mask = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
image_mask = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
video_mask = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
else:
|
||||
vision_start_mask = input_ids == vision_start_token_id
|
||||
image_mask = input_ids == image_token_id
|
||||
video_mask = input_ids == video_token_id
|
||||
|
||||
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
|
||||
image_mask = input_ids == image_token_id
|
||||
video_mask = input_ids == video_token_id
|
||||
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
|
||||
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
|
||||
|
||||
@ -1526,7 +1559,9 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
def _expand_dict_for_generation_visual(dict_to_expand):
|
||||
image_grid_thw = model_kwargs.get("image_grid_thw", None)
|
||||
video_grid_thw = model_kwargs.get("video_grid_thw", None)
|
||||
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
|
||||
image_nums, video_nums = self._get_image_nums_and_video_nums(
|
||||
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
|
||||
)
|
||||
|
||||
def _repeat_interleave_samples(x, lengths, repeat_times):
|
||||
samples = torch.split(x, lengths)
|
||||
@ -1582,10 +1617,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
|
||||
return dict_to_expand
|
||||
|
||||
# input_ids is required for expanding visual inputs
|
||||
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
|
||||
if input_ids is not None and input_ids.numel() != 0:
|
||||
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
|
||||
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
||||
|
@ -595,7 +595,14 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
"""
|
||||
_, patch_size, _ = image_hidden_states.shape
|
||||
|
||||
image_mask = input_ids == self.image_token_id
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_mask = image_mask[..., 0] # slice off the hidden dim
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
num_image_tokens = image_mask.sum(dim=1)
|
||||
if not torch.all(num_image_tokens % patch_size == 0):
|
||||
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
|
||||
@ -717,14 +724,8 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache:
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
|
||||
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
|
||||
@ -732,12 +733,13 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
# START VISUAL INPUTS INTEGRATION
|
||||
if pixel_values is not None and image_hidden_states is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
|
||||
elif pixel_values is not None:
|
||||
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device)
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
if inputs_embeds is not None and image_hidden_states is not None:
|
||||
if pixel_values is not None:
|
||||
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
|
||||
|
||||
if image_hidden_states is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self.inputs_merger(
|
||||
@ -996,27 +998,11 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
**kwargs,
)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# but IDEFICS requires both ids and embeds to be present
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs["input_ids"] = input_ids

if image_hidden_states is not None:
if image_hidden_states is not None or cache_position[0] != 0:
model_inputs["pixel_values"] = None
model_inputs["pixel_attention_mask"] = None

return model_inputs

def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
# Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
return model_kwargs


__all__ = ["SmolVLMForConditionalGeneration", "SmolVLMPreTrainedModel", "SmolVLMModel", "SmolVLMVisionTransformer"]
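Together, the two methods above make the vision tower run only once per generation: step 0 consumes `pixel_values`, the resulting `image_hidden_states` are copied back into `model_kwargs`, and every later step drops the pixel inputs. A minimal sketch of that gating, using an illustrative helper name rather than the model's API:

def gate_pixel_inputs(model_inputs, image_hidden_states, cache_position):
    # Once image encodings exist, or we are past the first decoding step,
    # the raw pixels are no longer needed and are nulled out.
    if image_hidden_states is not None or cache_position[0] != 0:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_attention_mask"] = None
    return model_inputs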
|
@ -180,7 +180,14 @@ class SmolVLMModel(Idefics3Model):
|
||||
):
|
||||
_, patch_size, _ = image_hidden_states.shape
|
||||
|
||||
image_mask = input_ids == self.image_token_id
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_mask = image_mask[..., 0] # slice off the hidden dim
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
num_image_tokens = image_mask.sum(dim=1)
|
||||
if not torch.all(num_image_tokens % patch_size == 0):
|
||||
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
|
||||
@ -296,14 +303,8 @@ class SmolVLMModel(Idefics3Model):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache:
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
|
||||
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(input_ids.device)
|
||||
@ -311,12 +312,13 @@ class SmolVLMModel(Idefics3Model):
|
||||
# START VISUAL INPUTS INTEGRATION
|
||||
if pixel_values is not None and image_hidden_states is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
|
||||
elif pixel_values is not None:
|
||||
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(input_ids.device)
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
if inputs_embeds is not None and image_hidden_states is not None:
|
||||
if pixel_values is not None:
|
||||
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask).to(inputs_embeds.device)
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=inputs_embeds.device)
|
||||
|
||||
if image_hidden_states is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self.inputs_merger(
|
||||
|
@ -328,12 +328,6 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
|
||||
"time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -343,10 +337,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -359,10 +361,18 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_features.shape[0] * video_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -233,11 +233,6 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

@ -246,10 +241,18 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)

special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

n_image_tokens = (special_image_mask).sum()
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
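With the `pixel_values`/`inputs_embeds` exclusivity check removed above, a caller can precompute the text embeddings and still hand the model pixel inputs; when `input_ids` are omitted, the image-token mask is rebuilt from the embeddings. A rough usage sketch, not verified end to end; the checkpoint id, prompt format, and image path are illustrative assumptions:

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "llava-hf/vip-llava-7b-hf"  # assumed checkpoint; any VipLlava/LLaVA-style model should behave the same
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

image = Image.open("example.jpg")  # placeholder image path
inputs = processor(images=image, text="USER: <image>\nWhat is shown? ASSISTANT:", return_tensors="pt").to(model.device)

# Previously this combination raised a ValueError; now the embeddings and the pixels
# can be passed together, with the prompt ids dropped in favour of the embeddings.
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
output = model.generate(**inputs, inputs_embeds=inputs_embeds, max_new_tokens=20)
print(processor.batch_decode(output, skip_special_tokens=True)[0])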
@ -136,11 +136,6 @@ class VipLlavaModel(LlavaModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -149,10 +144,18 @@ class VipLlavaModel(LlavaModel):
|
||||
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
|
||||
)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -118,27 +118,6 @@ from unittest.mock import patch
from transformers.utils import is_sklearn_available


# TODO: raushan remove this when VLMs start accepting input embeds
VLM_CLASS_NAMES = [
"llava",
"idefics2",
"idefics3",
"mllama",
"paligemma",
"emu3",
"gotocr2",
"qwen2vl",
"qwen2_5_vl",
"ayavision",
"janus",
"gemma3",
"mistral3",
"chameleon",
"internvl",
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`
]


class GenerationTesterMixin:
input_name = "input_ids"
model_tester = None
@ -1228,7 +1207,23 @@ class GenerationTesterMixin:
|
||||
"blip2", # overridden `generate()`
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
*VLM_CLASS_NAMES, # shouldn't suggest image tokens
|
||||
# All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
|
||||
"llava",
|
||||
"idefics2",
|
||||
"idefics3",
|
||||
"mllama",
|
||||
"paligemma",
|
||||
"emu3",
|
||||
"gotocr2",
|
||||
"qwen2vl",
|
||||
"qwen2_5_vl",
|
||||
"ayavision",
|
||||
"janus",
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"chameleon",
|
||||
"internvl",
|
||||
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
@ -1641,6 +1636,58 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_from_random_inputs_embeds(self):
|
||||
"""
|
||||
Text-only: Tests that different `inputs_embeds` generate different outputs in models with `main_input=="input_ids"`.
|
||||
Some models have 'images' as main input and thus can't generate with random text embeddings.
|
||||
See `test_generate_from_inputs_embeds` for more general checks.
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
continue
|
||||
|
||||
# No easy fix, let's skip the test for now
|
||||
has_complex_embeds_computation = any(
|
||||
model_name in model_class.__name__.lower() for model_name in ["moshi"]
|
||||
)
|
||||
|
||||
if model_class.main_input_name != "input_ids" or has_complex_embeds_computation:
|
||||
self.skipTest(
|
||||
"The model's main input name in not `input_ids` and we need kwargs from input dict as well."
|
||||
)
|
||||
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
}
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, **generation_kwargs)
|
||||
|
||||
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
|
||||
# same, but the logits will almost surely be different)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(
|
||||
input_ids=input_ids, inputs_embeds=random_embeds, **generation_kwargs
|
||||
)
|
||||
for i in range(len(outputs_from_rand_embeds.scores)):
|
||||
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
@ -1662,34 +1709,22 @@ class GenerationTesterMixin:
|
||||
continue
|
||||
|
||||
# There are a few exception patterns in this test:
|
||||
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
|
||||
requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"])
|
||||
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
||||
# 1 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
||||
# than calling the embedding layer with `input_ids`. Subcases of this exception:
|
||||
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
||||
# 1.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
|
||||
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
|
||||
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
# HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive;
|
||||
# this is similar to VLMs and should likely be standardized for similar audio models in the future,
|
||||
# then made generic here.
|
||||
if "granitespeech" in model_class.__name__.lower():
|
||||
inputs_dict.pop("input_features", None)
|
||||
|
||||
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
|
||||
# 1.B - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
|
||||
has_complex_embeds_computation = any(
|
||||
model_name in model_class.__name__.lower() for model_name in ["moshi"]
|
||||
)
|
||||
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
|
||||
# 2 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
|
||||
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
|
||||
missing_attention_mask = "attention_mask" not in inputs_dict
|
||||
|
||||
@ -1702,31 +1737,23 @@ class GenerationTesterMixin:
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
"use_cache": True,
|
||||
}
|
||||
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
|
||||
outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
|
||||
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
|
||||
# The output of the two calls should be the same.
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(
|
||||
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
if not has_complex_embeds_computation:
|
||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
||||
|
||||
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
|
||||
# same, but the logits will almost surely be different)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(
|
||||
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
for i in range(len(outputs_from_rand_embeds.scores)):
|
||||
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
|
||||
|
||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||
# be the same
|
||||
if not (requires_inputs_ids or missing_attention_mask):
|
||||
if not missing_attention_mask:
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
@ -1753,17 +1780,6 @@ class GenerationTesterMixin:
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
self.skipTest(reason="This model does not support `inputs_embeds` in generation")
|
||||
|
||||
# Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
|
||||
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
|
||||
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
|
||||
model.config.use_cache = True
|
||||
@ -1925,14 +1941,6 @@ class GenerationTesterMixin:
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
|
@ -189,49 +189,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
||||
self.model_tester = AriaVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
@ -270,14 +227,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_from_inputs_embeds_0_greedy(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_from_inputs_embeds_1_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
@ -62,7 +62,7 @@ class AyaVisionVisionText2TextModelTester:
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_index=1,
|
||||
image_token_index=2,
|
||||
num_channels=3,
|
||||
image_size=64,
|
||||
model_type="aya_vision",
|
||||
@ -183,49 +183,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip("Failing because of unique cache (HybridCache)")
|
||||
def test_model_outputs_equivalence(self, **kwargs):
|
||||
pass
|
||||
@ -285,10 +242,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Failing because of unique cache (HybridCache)")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
@ -20,7 +20,6 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
|
||||
from transformers.testing_utils import (
|
||||
@ -674,15 +673,6 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
pass
|
||||
|
||||
@unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
|
||||
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
|
||||
class Blip2TextModelTester:
|
||||
|
@ -355,49 +355,6 @@ class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unit
|
||||
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
|
||||
_ = model(input_ids=input_ids, pixel_values=pixel_values)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ChameleonIntegrationTest(unittest.TestCase):
|
||||
|
@ -189,50 +189,6 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester = ColPaliForRetrievalModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@slow
|
||||
@require_vision
|
||||
def test_colpali_forward_inputs(self):
|
||||
|
@ -331,49 +331,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(
|
||||
"Emu3 has a VQ module that uses `weight.data` directly in forward which prevent offloding on that module"
|
||||
)
|
||||
|
@ -131,10 +131,6 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
pass
|
||||
|
@ -13,12 +13,10 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch GLM-4.1V model."""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@ -237,11 +235,6 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
@unittest.skip("Cannot generate from inputs embeds with pixel values")
|
||||
def test_generate_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Size mismatch")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
@ -250,34 +243,11 @@ class Glm4vModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot generate from inputs embeds with pixel values")
|
||||
@unittest.skip("Error with compilation")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
# The multimodal base model embeds will not match ids, due to pixel values. We can't change base test
|
||||
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
|
||||
# TODO: @raushan
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
del inputs["image_grid_thw"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# RoPE index doesn't match when using embeddings
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -51,9 +51,6 @@ class GotOcr2VisionText2TextModelTester:
|
||||
num_channels=3,
|
||||
ignore_index=-100,
|
||||
image_size=64,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_index=1,
|
||||
model_type="got_ocr2",
|
||||
is_training=True,
|
||||
@ -71,6 +68,9 @@ class GotOcr2VisionText2TextModelTester:
|
||||
"rope_theta": 10000,
|
||||
"mlp_ratio": 4,
|
||||
"tie_word_embeddings": True,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 3,
|
||||
"pad_token_id": 4,
|
||||
},
|
||||
vision_config={
|
||||
"num_hidden_layers": 2,
|
||||
@ -85,9 +85,9 @@ class GotOcr2VisionText2TextModelTester:
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = text_config["bos_token_id"]
|
||||
self.eos_token_id = text_config["eos_token_id"]
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.image_token_index = image_token_index
|
||||
self.model_type = model_type
|
||||
self.text_config = text_config
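This tester change (mirrored in the other testers touched by this diff) reads the special token ids from the text sub-config instead of keeping separate constructor arguments, presumably so that the ids used to build test inputs always agree with the config the model is instantiated from. A toy restatement with illustrative values:

# Illustrative only: keep the tester's special token ids in sync with its text config.
text_config = {"bos_token_id": 2, "eos_token_id": 3, "pad_token_id": 4}

class TinyTester:
    def __init__(self, text_config):
        self.text_config = text_config
        self.bos_token_id = text_config["bos_token_id"]
        self.eos_token_id = text_config["eos_token_id"]
        self.pad_token_id = text_config["pad_token_id"]

assert TinyTester(text_config).pad_token_id == 4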
|
||||
@ -109,9 +109,6 @@ class GotOcr2VisionText2TextModelTester:
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
model_type=self.model_type,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
image_token_index=self.image_token_index,
|
||||
)
|
||||
|
||||
@ -127,7 +124,6 @@ class GotOcr2VisionText2TextModelTester:
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
# input_ids[:, -1] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_index] = self.pad_token_id
|
||||
input_ids[:, : self.num_image_tokens] = self.image_token_index
|
||||
|
||||
@ -181,55 +177,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
|
||||
)
|
||||
|
@ -315,13 +315,6 @@ class IdeficsModelTester:
|
||||
def prepare_pixel_values(self):
|
||||
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@unittest.skip(reason="Idefics has a hard requirement on SDPA, skipping this test")
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
@ -611,6 +604,12 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMi
|
||||
def test_sdpa_can_dispatch_non_composite_models(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Idefics can't do text-only inference")
|
||||
def test_generate_from_random_inputs_embeds(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
|
||||
@ -899,6 +898,12 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
def test_generation_tester_mixin_inheritance(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Idefics can't do text-only inference")
|
||||
def test_generate_from_random_inputs_embeds(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
|
@ -108,6 +108,7 @@ class Idefics2VisionText2TextModelTester:
|
||||
image_token_id=99,
|
||||
):
|
||||
self.parent = parent
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.is_training = is_training
|
||||
self.batch_size = batch_size
|
||||
self.num_images = num_images
|
||||
@ -158,6 +159,7 @@ class Idefics2VisionText2TextModelTester:
|
||||
|
||||
# For simplicity just set the last n tokens to the image token
|
||||
n_image_tokens_per_batch = self.num_images * self.perceiver_config["resampler_n_latents"]
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
inputs_dict = {
|
||||
|
@ -96,6 +96,7 @@ class Idefics3VisionText2TextModelTester:
|
||||
image_token_id=57,
|
||||
):
|
||||
self.parent = parent
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.is_training = is_training
|
||||
self.batch_size = batch_size
|
||||
self.num_images = num_images
|
||||
@ -148,6 +149,7 @@ class Idefics3VisionText2TextModelTester:
|
||||
|
||||
# For simplicity just set the last n tokens to the image token
|
||||
n_image_tokens_per_batch = self.seq_length
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, -n_image_tokens_per_batch:] = self.image_token_id
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
inputs_dict = {
|
||||
|
@ -20,7 +20,6 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
@ -522,12 +521,6 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -656,13 +649,6 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
pass
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
"""
|
||||
|
@ -20,7 +20,6 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
@ -535,12 +534,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -669,13 +662,6 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
||||
# They should result in very similar logits
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skip(
|
||||
"InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
|
||||
)
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
pass
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
"""
|
||||
|
@ -63,9 +63,6 @@ class InternVLVisionText2TextModelTester:
|
||||
image_seq_length=64,
|
||||
vision_feature_layer=-1,
|
||||
ignore_index=-100,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_id=1,
|
||||
num_channels=3,
|
||||
image_size=64,
|
||||
@ -85,9 +82,9 @@ class InternVLVisionText2TextModelTester:
|
||||
"rope_theta": 10000,
|
||||
"mlp_ratio": 4,
|
||||
"tie_word_embeddings": True,
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 0,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 3,
|
||||
"eos_token_id": 4,
|
||||
"pad_token_id": 5,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": 32,
|
||||
@ -103,9 +100,9 @@ class InternVLVisionText2TextModelTester:
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = text_config["bos_token_id"]
|
||||
self.eos_token_id = text_config["eos_token_id"]
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.image_token_id = image_token_id
|
||||
self.model_type = model_type
|
||||
self.text_config = text_config
|
||||
@ -128,9 +125,6 @@ class InternVLVisionText2TextModelTester:
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
model_type=self.model_type,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
image_token_id=self.image_token_id,
|
||||
image_seq_length=self.image_seq_length,
|
||||
vision_feature_layer=self.vision_feature_layer,
|
||||
@ -148,7 +142,6 @@ class InternVLVisionText2TextModelTester:
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
# input_ids[:, -1] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, : self.image_seq_length] = self.image_token_id
|
||||
|
||||
@ -222,49 +215,6 @@ class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
@ -153,6 +153,7 @@ class JanusVisionText2TextModelTester:
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
vq_config=self.get_vq_config(),
|
||||
image_token_id=self.image_token_index,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
@ -200,50 +201,6 @@ class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, uni
|
||||
self.model_tester = JanusVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=JanusConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
del inputs["generation_mode"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# Overwrite inputs_embeds tests because we need to delete "pixel values" for VLMs.
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
del inputs["generation_mode"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
@ -457,14 +457,6 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
@unittest.skip(
|
||||
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
|
||||
)
|
||||
def test_generate_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@unittest.skip("KOSMOS-2 doesn't support padding")
|
||||
@ -613,6 +605,53 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# (Even with this call, there is still a memory leak of ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
|
||||
# NOTE: overwritten because Kosmos with ids prepares position ids differently from embeds
|
||||
# If the model gets ids, all pad tokens are masked from position ids. That is not possible with embeds
|
||||
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if this fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config.is_decoder = True
|
||||
|
||||
# Skip models without explicit support
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# Traditional way of generating text
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
input_ids[input_ids == config.get_text_config().pad_token_id] = 0
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"num_beams": num_beams,
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
"use_cache": True,
|
||||
}
|
||||
outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
|
||||
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
|
||||
# The output of the two calls should be the same.
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(
|
||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
||||
|
||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||
# be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
||||
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
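The override above exists because of the asymmetry its comment describes: when the model receives `input_ids`, pad positions can be dropped from the position index, which is impossible when only embeddings are given. A toy illustration of that difference (pad id 0, values made up), not the model's actual position-id code:

# Toy illustration of the comment above; ids and shapes are made up.
import torch

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 5, 6, 7]])
attention_mask = (input_ids != pad_token_id).long()

# From ids, pads can be masked out of the position index:
position_ids_from_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)

# From embeds alone there is no pad id to detect, so positions are just 0..L-1:
position_ids_from_embeds = torch.arange(input_ids.shape[1]).unsqueeze(0)

print(position_ids_from_ids)     # tensor([[0, 0, 0, 1, 2]])
print(position_ids_from_embeds)  # tensor([[0, 1, 2, 3, 4]])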
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
@ -196,49 +196,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
Tests that VLMs throw an error with an explicit message saying what is wrong
|
||||
|
@ -222,49 +222,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
Tests that VLMs throw an error with an explicit message saying what is wrong
|
||||
|
@ -86,7 +86,7 @@ class LlavaNextVideoVisionText2TextModelTester:
|
||||
"initializer_range": 0.02,
|
||||
"num_labels": 3,
|
||||
"num_choices": 4,
|
||||
"pad_token_id": 2,
|
||||
"pad_token_id": 3,
|
||||
},
|
||||
is_training=True,
|
||||
vision_config={
|
||||
@ -234,51 +234,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
del inputs["pixel_values_videos"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
del inputs["pixel_values_videos"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
Tests that VLMs throw an error with an explicit message saying what is wrong
|
||||
|
@ -230,49 +230,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def test_odd_sized_image(self):
|
||||
# prepare model configuration
|
||||
config = self.model_tester.get_config()
|
||||
|
@ -57,9 +57,6 @@ class Mistral3VisionText2TextModelTester:
|
||||
image_seq_length=4,
|
||||
vision_feature_layer=-1,
|
||||
ignore_index=-100,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_index=1,
|
||||
num_channels=3,
|
||||
image_size=30,
|
||||
@ -80,9 +77,9 @@ class Mistral3VisionText2TextModelTester:
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 1000000000.0,
|
||||
"sliding_window": None,
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 0,
|
||||
"pad_token_id": 0,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 3,
|
||||
"pad_token_id": 4,
|
||||
},
|
||||
vision_config={
|
||||
"model_type": "pixtral",
|
||||
@ -98,9 +95,9 @@ class Mistral3VisionText2TextModelTester:
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = text_config["bos_token_id"]
|
||||
self.eos_token_id = text_config["eos_token_id"]
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.image_token_index = image_token_index
|
||||
self.model_type = model_type
|
||||
self.text_config = text_config
|
||||
@ -209,49 +206,6 @@ class Mistral3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in LLava models")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
@ -199,49 +199,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
self.model_tester = PaliGemmaVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
@ -327,12 +284,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs doesn't accept inputs embeds and pixel values at the same time. So if the test passed for backbone LM, it passes for VLM also"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||
)
|
||||
|
@ -183,49 +183,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
self.model_tester = PaliGemma2VisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
@ -311,12 +268,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs doesn't accept inputs embeds and pixel values at the same time. So if the test passed for backbone LM, it passes for VLM also"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||
)
|
||||
|
@ -22,7 +22,6 @@ from urllib.request import urlopen
|
||||
|
||||
import librosa
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@ -289,10 +288,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="QwenOmniThinker does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="QwenOmniThinker does not support output_hidden_states test")
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
@ -337,11 +332,6 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
@unittest.skip("Cannot generate from inputs embeds")
|
||||
def test_generate_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
@ -357,39 +357,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the tes for VLMs"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
super().test_prompt_lookup_decoding_matches_greedy_search()
|
||||
|
||||
# The multimodal base model embeds will not match ids, due to pixel values. We can't change the base test
|
||||
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
|
||||
# TODO: @raushan
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2_5_VLIntegrationTest(unittest.TestCase):
|
||||
|
@ -313,35 +313,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
# The multimodal base model embeds will not match ids, due to pixel values. We can't change the base test
|
||||
# because in some models `pixel_values` are required. Will be fixed when we add support for merging `embeds+pixels`
|
||||
# TODO: @raushan
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
|
||||
@require_torch
|
||||
class Qwen2VLIntegrationTest(unittest.TestCase):
|
||||
|
@ -181,14 +181,6 @@ class SmolVLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Model does not support padding right")
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
pass
|
||||
@ -347,10 +339,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
|
||||
self.model_tester = SmolVLMVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SmolVLMConfig, has_text_modality=False)
|
||||
|
||||
@unittest.skip(reason="input_embeds cannot be passed in without input_ids")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Model does not support padding right")
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
pass
|
||||
@ -394,14 +382,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_from_inputs_embeds_0_greedy(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_from_inputs_embeds_1_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
@ -344,51 +344,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
continue
|
||||
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values_images"]
|
||||
del inputs["pixel_values_videos"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values_images"]
|
||||
del inputs["pixel_values_videos"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
Tests that VLMs throw an error with an explicit message saying what is wrong
|
||||
|
@ -192,49 +192,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
# Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
|
@ -2829,7 +2829,9 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="This model doesn't use `inputs_embeds`")
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
pad_token_id = (
|
||||
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 1
|
||||
)
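The mixin now resolves the pad id through the text sub-config, so composite vision+text configs take the same path as text-only ones. A compact restatement of that fallback, assuming `get_text_config()` returns the config itself for plain text models:

# Sketch of the fallback used just above; the helper name is illustrative.
def resolve_pad_token_id(config, default=1):
    pad_token_id = config.get_text_config().pad_token_id
    return pad_token_id if pad_token_id is not None else default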
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
if not self.is_encoder_decoder: