[VLMs] add helpers to get multimodal encodings (#37743)

* add helpers in VLMs

* fix tests and copies

* fix blip tests

* make fix-copies

* fix copies

* fixup
Raushan Turganbay, 2025-05-16 13:20:10 +02:00, committed by GitHub
parent 1e921a3a9c
commit 3ab47b6ce3
29 changed files with 1000 additions and 593 deletions
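For context, the new helpers expose the image-encoding path as a standalone call instead of burying it inside `forward`. A minimal usage sketch (the checkpoint name and the `.model` attribute access are illustrative assumptions, not part of this diff):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# hypothetical example checkpoint; any LLaVA-style checkpoint should behave similarly
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

inputs = processor(images=Image.new("RGB", (336, 336)), text="<image>", return_tensors="pt")
with torch.no_grad():
    # after this change both vision-feature arguments may be omitted
    # and fall back to the values stored in the model config
    image_features = model.model.get_image_features(pixel_values=inputs.pixel_values)
```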


@@ -1197,7 +1197,7 @@ class AriaModel(AriaPreTrainedModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: torch.FloatTensor = None,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
):
"""
@@ -1208,13 +1208,16 @@ class AriaModel(AriaPreTrainedModel):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True


@@ -1325,7 +1325,7 @@ class AriaModel(LlavaModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: torch.FloatTensor = None,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
):
"""
@@ -1336,13 +1336,16 @@ class AriaModel(LlavaModel):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True


@@ -213,8 +213,8 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
**kwargs,
):
"""
@@ -223,16 +223,25 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")


@@ -1956,6 +1956,50 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
def get_image_features(
self,
pixel_values: torch.FloatTensor,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0]
# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs
@auto_docstring
def forward(
self,
@@ -2047,37 +2091,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0]
# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
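A hedged usage sketch of the helper's two return modes (the checkpoint name is only an example):

```python
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

pixel_values = processor(images=Image.new("RGB", (224, 224)), return_tensors="pt").pixel_values
with torch.no_grad():
    # default: only the projected language-model inputs
    language_model_inputs = model.get_image_features(pixel_values)
    # return_dict=True: also the raw vision and Q-Former outputs, as `forward` consumes them above
    language_model_inputs, vision_outputs, query_outputs = model.get_image_features(
        pixel_values, return_dict=True
    )
```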


@@ -904,6 +904,12 @@ class ChameleonModel(ChameleonPreTrainedModel):
self.embed_tokens = value
def get_image_tokens(self, pixel_values: torch.FloatTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
)
return self.get_image_featues(pixel_values)
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -957,7 +963,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
image_tokens = self.get_image_features(pixel_values)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
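The shim above follows the usual deprecation pattern: keep the old name as a thin alias that warns and delegates. A generic, self-contained sketch (all names are illustrative, not from this diff):

```python
import logging

logger = logging.getLogger(__name__)


class ExampleModel:
    def get_image_features(self, pixel_values):
        # real implementation would tokenize/encode the images here
        return pixel_values

    def get_image_tokens(self, pixel_values):
        # backward-compatible alias: warn, then delegate to the new helper
        logger.warning(
            "`get_image_tokens()` is deprecated and will be removed in a future version. "
            "Use `get_image_features()` instead."
        )
        return self.get_image_features(pixel_values)
```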


@@ -1586,6 +1586,12 @@ class Emu3Model(Emu3PreTrainedModel):
self.text_model.set_input_embeddings(value)
def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
)
return self.get_image_featues(pixel_values)
def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1662,7 +1668,7 @@ class Emu3Model(Emu3PreTrainedModel):
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
image_tokens = self.get_image_features(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)


@@ -941,6 +941,12 @@ class Emu3Model(Emu3PreTrainedModel):
self.text_model.set_input_embeddings(value)
def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
)
return self.get_image_featues(pixel_values)
def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1017,7 +1023,7 @@ class Emu3Model(Emu3PreTrainedModel):
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
image_tokens = self.get_image_features(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)


@@ -125,9 +125,25 @@ class FuyuModel(FuyuPreTrainedModel):
f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
)
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
output_embeddings.device
)
return output_embeddings
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
for patch in pixel_values
]
return patch_embeddings
@auto_docstring
def forward(
self,
@@ -185,12 +201,7 @@ class FuyuModel(FuyuPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
.squeeze(0)
.to(inputs_embeds.device)
for patch in image_patches
]
patch_embeddings = self.get_image_features(image_patches)
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
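Because each image can yield a different number of patches, the Fuyu helper returns a Python list rather than one stacked tensor. A standalone sketch with made-up sizes and a stand-in projection:

```python
import torch
import torch.nn as nn

vision_embed_tokens = nn.Linear(300, 1024)  # stand-in for the model's patch projection
# two images with different patch counts, each shaped (1, num_patches, patch_dim)
image_patches = [torch.randn(1, 7, 300), torch.randn(1, 5, 300)]

patch_embeddings = [
    vision_embed_tokens(patch.to(vision_embed_tokens.weight.dtype)).squeeze(0)
    for patch in image_patches
]
print([e.shape for e in patch_embeddings])  # [torch.Size([7, 1024]), torch.Size([5, 1024])]
```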


@@ -961,13 +961,57 @@ class Idefics2Model(Idefics2PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
num_images, _, vision_hidden_size = image_hidden_states.shape
special_image_token_mask = input_ids == self.image_token_id
new_inputs_embeds = inputs_embeds.clone()
reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device)
return new_inputs_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
pixel_attention_mask (`torch.LongTensor`, *optional*):
The attention mask indicating padded regions in the image.
"""
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
image_hidden_states = image_hidden_states.last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
)
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
return image_hidden_states
@can_return_tuple
@auto_docstring(
custom_intro="""
@@ -1052,45 +1096,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
)
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
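The pixel-to-patch mask conversion above is easiest to see on a toy input. A self-contained sketch of the `unfold` trick (sizes are made up; the logic mirrors the helper):

```python
import torch

patch_size = 2
# 6x6 pixel mask for one image; pretend the right third is padding
pixel_attention_mask = torch.ones(1, 6, 6, dtype=torch.bool)
pixel_attention_mask[0, :, 4:] = False

# tile the mask into non-overlapping patch_size x patch_size cells
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
# a patch is kept only when every pixel inside it is valid
patch_attention_mask = patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size
print(patch_attention_mask[0])  # the last column is False: those patches are fully padded
```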


@@ -692,16 +692,59 @@ class Idefics3Model(Idefics3PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
num_images, _, vision_hidden_size = image_hidden_states.shape
special_image_token_mask = input_ids == self.image_token_id
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
new_inputs_embeds = inputs_embeds.clone()
reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = image_hidden_states
return new_inputs_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
pixel_attention_mask (`torch.LongTensor`, *optional*):
The attention mask indicating padded regions in the image.
"""
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
image_hidden_states = image_hidden_states.last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
return image_hidden_states
@can_return_tuple
@auto_docstring(
custom_intro="""
@@ -774,43 +817,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)


@@ -1466,6 +1466,55 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
def get_image_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs
@can_return_tuple
@auto_docstring
def forward(
@@ -1555,40 +1604,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1690,30 +1714,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
self._preprocess_accelerate()
batch_size = pixel_values.shape[0]
image_embeds = self.vision_model(
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values,
return_dict=True,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
language_model_inputs = self.language_projection(query_output)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1722,7 +1729,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)
if attention_mask is None:
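The Q-Former mask handling above concatenates an all-ones mask for the learned query tokens with the prompt's attention mask. A toy sketch of just that step (sizes are made up):

```python
import torch

batch_size, num_query_tokens, prompt_len = 2, 32, 10
qformer_input_ids = torch.randint(0, 30522, (batch_size, prompt_len))

query_attention_mask = torch.ones(batch_size, num_query_tokens, dtype=torch.long)
qformer_attention_mask = torch.ones_like(qformer_input_ids)
# query tokens always attend; the prompt mask follows them along the sequence axis
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
print(qformer_attention_mask.shape)  # torch.Size([2, 42])
```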


@@ -1470,6 +1470,23 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
def get_image_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
pass
@can_return_tuple
@auto_docstring
def forward(
@@ -1582,50 +1599,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1726,39 +1708,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
# preprocess for `accelerate`
self._preprocess_accelerate()
# we process in a batched way, later unbatch it back (video has frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
image_embeds = self.vision_model(
batch_size = pixel_values.shape[0]
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
return_dict=True,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
language_model_inputs = self.language_projection(query_output)
# unbatch the embeddings back by moving frames to seq-len
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1767,7 +1725,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)
if attention_mask is None:
@@ -1807,6 +1765,65 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
return outputs
def get_video_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, frames, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
"""
# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs
__all__ = [
"InstructBlipVideoVisionModel",


@@ -295,6 +295,76 @@ class InstructBlipVideoModel(InstructBlipModel):
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
def get_video_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, frames, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
"""
# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs
# Model supports only videos
def get_image_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
pass
def forward(
self,
pixel_values: torch.FloatTensor,
@@ -370,50 +440,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -514,39 +549,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
# preprocess for `accelerate`
self._preprocess_accelerate()
# we process in a batched way, later unbatch it back (video has frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
image_embeds = self.vision_model(
batch_size = pixel_values.shape[0]
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
return_dict=True,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
language_model_inputs = self.language_projection(query_output)
# unbatch the embeddings back by moving frames to seq-len
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -555,7 +566,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)
if attention_mask is None:
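The frame handling in `get_video_features` folds frames into the batch dimension for the vision encoder and unfolds them back into sequence length for the language model. A shape-only sketch with stand-in tensors:

```python
import torch

batch_size, frames, channel, height, width = 2, 4, 3, 224, 224
num_query_tokens, hidden_size = 32, 768
pixel_values = torch.randn(batch_size, frames, channel, height, width)

# fold frames into the batch dimension before the vision encoder
flat_pixels = pixel_values.reshape(batch_size * frames, channel, height, width)
# stand-in for the projected Q-Former output, one row per frame
language_model_inputs = torch.randn(batch_size * frames, num_query_tokens, hidden_size)
# unfold: each frame contributes num_query_tokens positions of LM sequence length
language_model_inputs = language_model_inputs.reshape(batch_size, num_query_tokens * frames, -1)
print(language_model_inputs.shape)  # torch.Size([2, 128, 768])
```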


@@ -1625,6 +1625,37 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
def set_input_embeddings(self, value):
self.text_model.model.embed_tokens = value
def get_image_features(
self,
pixel_values: torch.FloatTensor,
return_attentions: Optional[bool] = False,
interpolate_pos_encoding: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
return_attentions (`bool`, *optional*, defaults to `False`):
Whether to return `projection_attentions` or not.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional embeddings or not.
"""
vision_model_output = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
# The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
# normalized features
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
if return_attentions:
return image_embeds, projection_attentions
return image_embeds
@can_return_tuple
@auto_docstring
def forward(
@@ -1696,19 +1727,9 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
if image_embeds is None:
if pixel_values is None:
raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
vision_model_output = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
image_embeds, projection_attentions = self.get_image_features(
pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding
)
# The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
# normalized features
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
outputs = self.text_model(
input_ids=input_ids,
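The Kosmos-2 helper L2-normalizes the post-layernorm hidden states before projecting them into the text embedding space. A standalone sketch with a plain linear layer standing in for `image_to_text_projection` (the real module also returns attention weights):

```python
import torch
import torch.nn as nn

image_embeds = torch.randn(2, 64, 1024)                       # fake post-layernorm vision states
image_embeds = nn.functional.normalize(image_embeds, dim=-1)  # unit-norm feature vectors

projection = nn.Linear(1024, 2048)  # stand-in for the image-to-text projection
language_inputs = projection(image_embeds)
print(language_inputs.shape)  # torch.Size([2, 64, 2048])
```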


@@ -183,8 +183,8 @@ class LlavaModel(LlavaPreTrainedModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
**kwargs,
):
"""
@@ -193,16 +193,25 @@ class LlavaModel(LlavaPreTrainedModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")


@@ -365,8 +365,8 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -376,17 +376,26 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(


@@ -419,8 +419,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -430,17 +430,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`):
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
@@ -600,8 +609,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains video last hidden states from the vision tower and apply multimodal projection.
@@ -609,17 +618,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing all the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)


@@ -252,8 +252,8 @@ class LlavaNextVideoModel(LlavaNextModel):
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@ -263,17 +263,26 @@ class LlavaNextVideoModel(LlavaNextModel):
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
@ -311,8 +320,8 @@ class LlavaNextVideoModel(LlavaNextModel):
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains video last hidden states from the vision tower and apply multimodal projection.
@ -320,17 +329,26 @@ class LlavaNextVideoModel(LlavaNextModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)


@ -419,8 +419,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@ -430,17 +430,26 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(


@ -255,8 +255,8 @@ class Mistral3Model(Mistral3PreTrainedModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
**kwargs,
):
"""
@ -265,15 +265,19 @@ class Mistral3Model(Mistral3PreTrainedModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
image_sizes (`torch.Tensor`, *optional*):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all: output_hidden_states=True saves all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
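Because `None`-valued extras are filtered out of `kwargs` before the vision tower call, a caller only needs whatever the processor actually produced. A sketch, assuming `model` is a loaded `Mistral3Model` (or a wrapper exposing the same helper); shapes are illustrative.

import torch

pixel_values = torch.zeros(1, 3, 512, 512)
image_sizes = torch.tensor([[512, 512]])  # (height, width) per image
# vision_feature_layer is omitted and falls back to model.config.vision_feature_layer
image_features = model.get_image_features(pixel_values, image_sizes=image_sizes)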


@ -135,8 +135,8 @@ class Mistral3Model(LlavaModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
**kwargs,
):
"""
@ -145,15 +145,19 @@ class Mistral3Model(LlavaModel):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
image_sizes (`torch.Tensor`, *optional*):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all: output_hidden_states=True saves all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)


@ -2187,6 +2187,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_video_features(
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of the feature shape of each video in the LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
return video_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of the feature shape of each image in the LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def get_audio_features(
self,
input_features: torch.FloatTensor,
feature_attention_mask: Optional[torch.LongTensor] = None,
audio_feature_lengths: Optional[torch.LongTensor] = None,
):
"""
Encodes audios into continuous embeddings that can be forwarded to the language model.
Args:
input_features (`torch.FloatTensor`):
The tensors corresponding to the input audios.
feature_attention_mask (`torch.LongTensor`, *optional*):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1 for features that are not masked, 0 for features that are masked.
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
The length of the feature sequence of each audio in the LLM.
"""
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
else:
audio_feature_lengths = None
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
)
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
audio_outputs = self.audio_tower(
input_features,
feature_lens=feature_lens,
aftercnn_lens=audio_feat_lengths,
)
audio_features = audio_outputs.last_hidden_state
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
raise ValueError("length of audio_features should match audio_output_lengths")
return audio_features
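`get_audio_features` now wraps the audio-tower plumbing that previously lived inline in `forward`, including deriving per-clip lengths from the padding mask. A sketch of calling it on a padded batch, assuming `thinker` is a loaded `Qwen2_5OmniThinkerForConditionalGeneration`; the mel-bin count and frame lengths are illustrative.

import torch

input_features = torch.zeros(2, 128, 3000)  # (batch, num_mel_bins, max_frames), illustrative
feature_attention_mask = torch.ones(2, 3000, dtype=torch.long)
feature_attention_mask[1, 1500:] = 0  # second clip is shorter than the first
audio_features = thinker.get_audio_features(
    input_features,
    feature_attention_mask=feature_attention_mask,
)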
@auto_docstring
def forward(
self,
@ -2284,11 +2353,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text, audio, image and video
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
if input_features is not None:
audio_features = self.get_audio_features(
input_features,
feature_attention_mask=feature_attention_mask,
audio_feature_lengths=audio_feature_lengths,
)
audio_mask = (
(input_ids == self.config.audio_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
else:
audio_feature_lengths = None
if attention_mask is not None and position_ids is None:
if (
cache_position is None
@ -2315,63 +2430,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text , audios , image and video
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
if input_features is not None:
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
)
feature_lens = (
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
)
audio_outputs = self.audio_tower(
input_features,
feature_lens=feature_lens,
aftercnn_lens=audio_feat_lengths,
)
audio_features = audio_outputs.last_hidden_state
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
raise ValueError("length of audio_features should match audio_output_lengths")
audio_mask = (
(input_ids == self.config.audio_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
outputs = self.model(
attention_mask=attention_mask,
position_ids=position_ids,


@ -2181,6 +2181,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_video_features(
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
return video_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def get_audio_features(
self,
input_features: torch.FloatTensor,
feature_attention_mask: Optional[torch.LongTensor] = None,
audio_feature_lengths: Optional[torch.LongTensor] = None,
):
"""
Encodes audios into continuous embeddings that can be forwarded to the language model.
Args:
input_features (`torch.FloatTensor`):
The tensors corresponding to the input audios.
feature_attention_mask (`torch.LongTensor`, *optional*):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
1 for features that are not masked, 0 for features that are masked.
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
The length of the feature sequence of each audio in the LLM.
"""
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
else:
audio_feature_lengths = None
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
)
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
audio_outputs = self.audio_tower(
input_features,
feature_lens=feature_lens,
aftercnn_lens=audio_feat_lengths,
)
audio_features = audio_outputs.last_hidden_state
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
raise ValueError("length of audio_features should match audio_output_lengths")
return audio_features
@auto_docstring
def forward(
self,
@ -2278,11 +2347,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text, audio, image and video
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
if input_features is not None:
audio_features = self.get_audio_features(
input_features,
feature_attention_mask=feature_attention_mask,
audio_feature_lengths=audio_feature_lengths,
)
audio_mask = (
(input_ids == self.config.audio_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
else:
audio_feature_lengths = None
if attention_mask is not None and position_ids is None:
if (
cache_position is None
@ -2309,63 +2424,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
if inputs_embeds is None:
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text , audios , image and video
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
if input_features is not None:
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
)
feature_lens = (
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
)
audio_outputs = self.audio_tower(
input_features,
feature_lens=feature_lens,
aftercnn_lens=audio_feat_lengths,
)
audio_features = audio_outputs.last_hidden_state
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
raise ValueError("length of audio_features should match audio_output_lengths")
audio_mask = (
(input_ids == self.config.audio_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
outputs = self.model(
attention_mask=attention_mask,
position_ids=position_ids,


@ -1583,6 +1583,36 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
return position_ids, mrope_position_deltas
def get_video_features(
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
return video_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
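For the Qwen2.5-VL helpers the processor already returns the grid metadata alongside the flattened patches. A sketch, assuming `model` is a loaded `Qwen2_5_VLModel` and `inputs` is a processor output containing the keys below:

image_embeds = model.get_image_features(inputs["pixel_values"], inputs["image_grid_thw"])
video_embeds = model.get_video_features(inputs["pixel_values_videos"], inputs["video_grid_thw"])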
@auto_docstring
def forward(
self,
@ -1627,8 +1657,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
@ -1645,8 +1674,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:


@ -646,8 +646,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
@ -664,8 +663,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:


@ -1508,6 +1508,36 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
return position_ids, mrope_position_deltas
def get_video_features(
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
return video_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
@auto_docstring
def forward(
self,
@ -1549,9 +1579,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
@ -1567,9 +1596,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum()
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(


@ -632,6 +632,52 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
return merged_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
The tensors corresponding to the input images.
pixel_attention_mask (`torch.LongTensor`, *optional*):
The attention mask indicating padded regions in the image.
"""
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are all zeros.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
if not any(real_images_inds):
# no images, leave one empty image.
real_images_inds[0] = True
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=[pixel_values.shape[i] for i in (0, 2, 3)],
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
image_hidden_states = image_hidden_states.last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
return image_hidden_states
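A sketch of the padded-batch contract, assuming `model` is a loaded `SmolVLMModel`; shapes are illustrative, and all-zero image slots are treated as padding and dropped.

import torch

pixel_values = torch.zeros(1, 2, 3, 512, 512)  # (batch, max_num_images, channels, height, width)
pixel_values[0, 0] = torch.rand(3, 512, 512)   # one real image, one padding slot
# pixel_attention_mask is omitted, so it defaults to all-ones for the kept images
image_hidden_states = model.get_image_features(pixel_values)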
@can_return_tuple
@auto_docstring(
custom_intro="""
@ -704,48 +750,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
if not any(real_images_inds):
# no images, leave one empty image.
real_images_inds[0] = True
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=[pixel_values.shape[i] for i in (0, 2, 3)],
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)


@ -22,7 +22,7 @@ from torch import nn
from ...cache_utils import DynamicCache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import logging
from ...utils import auto_docstring, can_return_tuple, logging
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
from ..idefics3.modeling_idefics3 import (
@ -195,6 +195,64 @@ class SmolVLMModel(Idefics3Model):
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
return merged_embeds
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: Optional[torch.LongTensor] = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
The tensors corresponding to the input images.
pixel_attention_mask (`torch.LongTensor`, *optional*):
The attention mask indicating padded regions in the image.
"""
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are all zeros.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
if not any(real_images_inds):
# no images, leave one empty image.
real_images_inds[0] = True
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=[pixel_values.shape[i] for i in (0, 2, 3)],
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
image_hidden_states = image_hidden_states.last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
return image_hidden_states
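The patch mask derivation above is worth seeing in isolation: `unfold` tiles the pixel-level mask into patch-sized windows, and a patch is kept if any pixel in its window is valid. A self-contained toy example with tiny sizes for readability:

import torch

patch_size = 2
pixel_attention_mask = torch.zeros(1, 4, 4, dtype=torch.bool)
pixel_attention_mask[0, :2, :] = True  # top half of the image is valid
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# shape (1, 2, 2): True for the two top patches, False for the bottom ones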
@can_return_tuple
@auto_docstring(
custom_intro="""
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
max_num_images is the maximum number of images among the batch_size samples in the batch.
Padding images are not needed beyond padding the pixel_values at the entrance of the model.
For efficiency, we only pass through the vision_model's forward the real images by
discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
"""
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -249,48 +307,7 @@ class SmolVLMModel(Idefics3Model):
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
if not any(real_images_inds):
# no images, leave one empty image.
real_images_inds[0] = True
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=[pixel_values.shape[i] for i in (0, 2, 3)],
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state
# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)


@ -205,8 +205,8 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
def get_image_features(
self,
pixel_values_images: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@ -214,16 +214,25 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
Args:
pixel_values_images (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
@ -249,7 +258,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
def get_video_features(
self,
pixel_values_videos: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_layer: Optional[Union[int, List[int]]] = None,
):
"""
Obtains video last hidden states from the vision tower and apply multimodal projection.
@ -257,7 +266,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input videos.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
@ -265,6 +274,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`.
frames (`int`): Number of frames the videos have.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
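As elsewhere in this commit, both VideoLlava helpers now resolve omitted kwargs from the config. A closing sketch, assuming `model` is a loaded `VideoLlavaModel` and, as the docstring above states, that the video helper returns the frame count alongside the features; shapes are illustrative.

import torch

pixel_values_videos = torch.zeros(1, 8, 3, 224, 224)  # (batch_size, num_frames, channels, height, width)
video_features, frames = model.get_video_features(pixel_values_videos)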