Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
[VLMs] add helpers to get multimodal encodings (#37743)

* add helpers in VLMs
* fix tests and copies
* fix blip tests
* make fix-copies
* fix copies
* fixup

This commit is contained in:
parent 1e921a3a9c
commit 3ab47b6ce3
@@ -1197,7 +1197,7 @@ class AriaModel(AriaPreTrainedModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: torch.FloatTensor = None,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
):
"""
@@ -1208,13 +1208,16 @@ class AriaModel(AriaPreTrainedModel):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor]`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
@@ -1325,7 +1325,7 @@ class AriaModel(LlavaModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: torch.FloatTensor = None,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
):
"""
@@ -1336,13 +1336,16 @@ class AriaModel(LlavaModel):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor]`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
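Not part of the diff: a minimal usage sketch of the Aria helper above. It assumes `model` is an already-loaded `AriaModel` (or the `.model` attribute of the conditional-generation wrapper); tensor shapes are placeholders, and `pixel_mask` may be omitted now that it is `Optional`.

```python
import torch

# Placeholder inputs; real values would come from the Aria image processor.
pixel_values = torch.randn(1, 3, 490, 490)
pixel_mask = torch.ones(1, 490, 490, dtype=torch.long)

# `vision_feature_layer` keeps its default (-1); passing None would fall back
# to `model.config.vision_feature_layer`.
image_features = model.get_image_features(
    pixel_values=pixel_values,
    pixel_mask=pixel_mask,
)
print(image_features.shape)  # expected: (num_images, image_length, embed_dim)
```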
@@ -213,8 +213,8 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
**kwargs,
):
"""
@@ -223,16 +223,25 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)

if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
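For illustration only (not from the commit): with the relaxed AyaVision signature, both feature-selection arguments may now be omitted and are resolved from the config. `model` and `pixel_values` below are assumed to exist already.

```python
import torch

pixel_values = torch.randn(1, 3, 364, 364)  # placeholder image batch

# Omitted arguments fall back to `model.config.vision_feature_layer` and
# `model.config.vision_feature_select_strategy`.
image_features = model.get_image_features(pixel_values=pixel_values)

# Explicit values are still accepted, e.g. concatenating two layers:
image_features = model.get_image_features(
    pixel_values=pixel_values,
    vision_feature_layer=[-2, -1],
    vision_feature_select_strategy="default",
)
```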
@@ -1956,6 +1956,50 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

def get_image_features(
self,
pixel_values: torch.FloatTensor,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0]

# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs

@auto_docstring
def forward(
self,
@@ -2047,37 +2091,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0]

# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
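A short sketch of the new BLIP-2 helper in isolation (assumption: `model` is a loaded `Blip2ForConditionalGeneration`; the input tensor is a placeholder). It mirrors what `forward()` now does internally.

```python
import torch

pixel_values = torch.randn(1, 3, 224, 224)  # placeholder image batch

# Default: only the projected Q-Former queries, ready to prepend to text embeddings.
language_model_inputs = model.get_image_features(pixel_values)

# `return_dict=True` additionally exposes the raw vision-encoder and Q-Former outputs.
language_model_inputs, vision_outputs, query_outputs = model.get_image_features(
    pixel_values, return_dict=True
)
```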
@@ -904,6 +904,12 @@ class ChameleonModel(ChameleonPreTrainedModel):
self.embed_tokens = value

def get_image_tokens(self, pixel_values: torch.FloatTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
)
return self.get_image_features(pixel_values)

def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -957,7 +963,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
image_tokens = self.get_image_features(pixel_values)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
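Sketch only, not part of the diff: for Chameleon the helper returns *discrete* VQGAN/BPE token ids rather than continuous embeddings, so they can be scattered straight into `input_ids`. `model` is assumed to be a loaded `ChameleonModel`.

```python
# `image_tokens` has one id per image position and shares the text vocabulary.
image_tokens = model.get_image_features(pixel_values)

# The old entry point still works for now, but warns and is slated for removal in v4.58:
image_tokens = model.get_image_tokens(pixel_values)  # deprecated alias
```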
@@ -1586,6 +1586,12 @@ class Emu3Model(Emu3PreTrainedModel):
self.text_model.set_input_embeddings(value)

def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
)
return self.get_image_features(pixel_values, image_sizes)

def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1662,7 +1668,7 @@ class Emu3Model(Emu3PreTrainedModel):
)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
image_tokens = self.get_image_features(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
@@ -941,6 +941,12 @@ class Emu3Model(Emu3PreTrainedModel):
self.text_model.set_input_embeddings(value)

def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
)
return self.get_image_features(pixel_values, image_sizes)

def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1017,7 +1023,7 @@ class Emu3Model(Emu3PreTrainedModel):
)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
image_tokens = self.get_image_features(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
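For reference (not in the commit): Emu3's helper additionally takes the per-image sizes so the VQGAN tokenizer can unpad correctly. `model` is assumed to be a loaded `Emu3Model`; values are placeholders.

```python
import torch

pixel_values = torch.randn(1, 3, 512, 512)        # placeholder image
image_sizes = torch.tensor([[512, 512]])          # (height, width) per image

image_tokens = model.get_image_features(pixel_values, image_sizes)
```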
@@ -125,9 +125,25 @@ class FuyuModel(FuyuPreTrainedModel):
f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
)
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
output_embeddings.device
)
return output_embeddings

def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
for patch in pixel_values
]
return patch_embeddings

@auto_docstring
def forward(
self,
@@ -185,12 +201,7 @@ class FuyuModel(FuyuPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
.squeeze(0)
.to(inputs_embeds.device)
for patch in image_patches
]
patch_embeddings = self.get_image_features(image_patches)
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
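A hedged sketch of the Fuyu helper (not from the diff): `model` is assumed to be a loaded `FuyuModel`, and `image_patches` the per-image flattened patch tensors produced by the Fuyu processor. The helper returns one embedding tensor per image, which `gather_continuous_embeddings` then scatters into the word embeddings.

```python
# `image_patches` is an iterable of patch tensors, one entry per image.
patch_embeddings = model.get_image_features(image_patches)
# patch_embeddings[i] has shape (num_patches_i, hidden_size)
```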
@@ -961,13 +961,57 @@ class Idefics2Model(Idefics2PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
num_images, _, vision_hidden_size = image_hidden_states.shape
special_image_token_mask = input_ids == self.image_token_id
new_inputs_embeds = inputs_embeds.clone()
reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device)
return new_inputs_embeds

def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
pixel_attention_mask (`torch.LongTensor`, *optional*):
The attention mask indicating padded regions in the image.
"""
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
image_hidden_states = image_hidden_states.last_hidden_state

# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
)
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
return image_hidden_states

@can_return_tuple
@auto_docstring(
custom_intro="""
@@ -1052,45 +1096,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()

# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state

# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
)

image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
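Illustrative only (not in the commit): the Idefics2 helper takes 5-D inputs laid out as (batch_size, num_images, channels, height, width); all-zero images are treated as padding and dropped inside the helper. `model` is assumed to be a loaded `Idefics2Model`, shapes are placeholders.

```python
import torch

pixel_values = torch.randn(1, 2, 3, 224, 224)                       # placeholder: 2 images
pixel_attention_mask = torch.ones(1, 2, 224, 224, dtype=torch.long)  # placeholder mask

image_hidden_states = model.get_image_features(pixel_values, pixel_attention_mask)
# -> flattened hidden states, one row per image token after the connector/resampler
```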
@@ -692,16 +692,59 @@ class Idefics3Model(Idefics3PreTrainedModel):
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
num_images, _, vision_hidden_size = image_hidden_states.shape
special_image_token_mask = input_ids == self.image_token_id
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
new_inputs_embeds = inputs_embeds.clone()
reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = image_hidden_states
return new_inputs_embeds

def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
pixel_attention_mask (`torch.LongTensor`, *optional*):
The attention mask indicating padded regions in the image.
"""
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
image_hidden_states.last_hidden_state

# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states.last_hidden_state)
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
return image_hidden_states

@can_return_tuple
@auto_docstring(
custom_intro="""
@@ -774,43 +817,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
if pixel_values is not None and image_hidden_states is not None:
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
elif pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()

# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(
batch_size * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state

# Modality projection & resampling
image_hidden_states = self.connector(image_hidden_states)

image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
elif image_hidden_states is not None:
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
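Sketch only: the Idefics3 call mirrors the Idefics2 one above; the main internal differences are that the connector is applied without an attention mask and a patch is kept if any of its pixels is valid. `model` is assumed to be a loaded `Idefics3Model`.

```python
image_hidden_states = model.get_image_features(pixel_values, pixel_attention_mask)
```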
@@ -1466,6 +1466,55 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

def get_image_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs

@can_return_tuple
@auto_docstring
def forward(
@@ -1555,40 +1604,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1690,30 +1714,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
self._preprocess_accelerate()

batch_size = pixel_values.shape[0]
image_embeds = self.vision_model(
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values,
return_dict=True,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state

image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

language_model_inputs = self.language_projection(query_output)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1722,7 +1729,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_id", None) is not None:
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)

if attention_mask is None:
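Not part of the diff: a hedged sketch of the InstructBLIP helper. Unlike BLIP-2, the instruction prompt is also fed to the Q-Former, so the tokenized Q-Former inputs must be passed alongside the images. `model` and `inputs` are assumed to come from an already-loaded `InstructBlipForConditionalGeneration` and its processor.

```python
# `inputs` is assumed to be processor(images=image, text=prompt, return_tensors="pt").
language_model_inputs = model.get_image_features(
    inputs.pixel_values,
    qformer_input_ids=inputs.qformer_input_ids,
    qformer_attention_mask=inputs.qformer_attention_mask,
)

# With return_dict=True the vision and Q-Former outputs are returned as well,
# which is what forward() and generate() now consume internally.
language_model_inputs, vision_outputs, query_outputs = model.get_image_features(
    inputs.pixel_values,
    qformer_input_ids=inputs.qformer_input_ids,
    return_dict=True,
)
```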
@@ -1470,6 +1470,23 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

def get_image_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
pass

@can_return_tuple
@auto_docstring
def forward(
@@ -1582,50 +1599,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)

if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)

qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)

# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1726,39 +1708,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
# preprocess for `accelerate`
self._preprocess_accelerate()

# we process in a batched way, later unbatch it back (video has frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

image_embeds = self.vision_model(
batch_size = pixel_values.shape[0]
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
return_dict=True,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)

qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

language_model_inputs = self.language_projection(query_output)

# unbatch the embeddings back by moving frames to seq-len
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -1767,7 +1725,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)

if attention_mask is None:
@@ -1807,6 +1765,65 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel

return outputs

def get_video_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)

if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)

qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)

# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs


__all__ = [
"InstructBlipVideoVisionModel",
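A usage sketch (assumption, not from the commit): InstructBlipVideo encodes videos, so the dedicated helper is `get_video_features`, while `get_image_features` is intentionally a no-op. Videos arrive as 5-D tensors (batch, frames, channels, height, width); frames are flattened for the vision encoder and folded back into the sequence dimension, so each frame contributes `num_query_tokens` positions. `model` is assumed to be a loaded `InstructBlipVideoForConditionalGeneration`; the tensors are placeholders.

```python
import torch

pixel_values = torch.randn(1, 4, 3, 224, 224)          # placeholder: 4 frames
qformer_input_ids = torch.tensor([[101, 2054, 102]])   # placeholder prompt ids

video_features = model.get_video_features(
    pixel_values, qformer_input_ids=qformer_input_ids
)
# video_features.shape[1] == model.config.num_query_tokens * 4
```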
@@ -295,6 +295,76 @@ class InstructBlipVideoModel(InstructBlipModel):


class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
def get_video_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)

if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)

qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)

# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs

# Model supports only videos
def get_image_features(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.LongTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
pass

def forward(
self,
pixel_values: torch.FloatTensor,
@@ -370,50 +440,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# step 1: forward the images through the vision encoder,
# we process in a batched way, later unbatch it back (video has frames=4 always)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)

if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)

qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)

# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -514,39 +549,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
# preprocess for `accelerate`
self._preprocess_accelerate()

# we process in a batched way, later unbatch it back (video has frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

image_embeds = self.vision_model(
batch_size = pixel_values.shape[0]
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
pixel_values,
return_dict=True,
qformer_input_ids=qformer_input_ids,
qformer_attention_mask=qformer_attention_mask,
interpolate_pos_encoding=interpolate_pos_encoding,
).last_hidden_state
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)

qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

language_model_inputs = self.language_projection(query_output)

# unbatch the embeddings back by moving frames to seq-len
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
@@ -555,7 +566,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_id", None) is not None:
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
input_ids = input_ids.repeat(batch_size, 1)

if attention_mask is None:
@@ -1625,6 +1625,37 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
def set_input_embeddings(self, value):
self.text_model.model.embed_tokens = value

def get_image_features(
self,
pixel_values: torch.FloatTensor,
return_attentions: Optional[bool] = False,
interpolate_pos_encoding: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
return_attentions (`bool`, *optional*, defaults to `False`):
Whether to return `projection_attentions` or not.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate positional embeddings or not.
"""
vision_model_output = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
# The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
# normalized features
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

if return_attentions:
return image_embeds, projection_attentions
return image_embeds

@can_return_tuple
@auto_docstring
def forward(
@@ -1696,19 +1727,9 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
if image_embeds is None:
if pixel_values is None:
raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

vision_model_output = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
image_embeds, projection_attentions = self.get_image_features(
pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding
)
# The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
# normalized features
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

outputs = self.text_model(
input_ids=input_ids,
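Illustrative sketch (not part of the commit): the Kosmos-2 helper runs the vision encoder, applies the post-layernorm and L2 normalization, and projects into the text embedding space. `model` is assumed to be a loaded `Kosmos2Model` and `pixel_values` an already-prepared image batch.

```python
# Projected image embeddings only:
image_embeds = model.get_image_features(pixel_values)

# Also expose the image-to-text projection attentions, as forward() does internally:
image_embeds, projection_attentions = model.get_image_features(
    pixel_values, return_attentions=True
)
```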
@ -183,8 +183,8 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -193,16 +193,25 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
|
||||
|
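The None-defaulting added to these signatures follows one pattern throughout the PR: explicit arguments win, otherwise the value comes from the config, and the select strategy is validated. A standalone sketch of that pattern (the class and function names here are illustrative, not from the diff):

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class VisionCfg:  # stand-in for the model config
    vision_feature_layer: Union[int, List[int]] = -2
    vision_feature_select_strategy: str = "default"

def resolve_vision_kwargs(
    config: VisionCfg,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
):
    # Mirror the fallback logic added to get_image_features above.
    layer = vision_feature_layer if vision_feature_layer is not None else config.vision_feature_layer
    strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else config.vision_feature_select_strategy
    )
    if strategy not in ["default", "full"]:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    return layer, strategy

print(resolve_vision_kwargs(VisionCfg()))             # (-2, 'default')
print(resolve_vision_kwargs(VisionCfg(), 3, "full"))  # (3, 'full')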
@ -365,8 +365,8 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -376,17 +376,26 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
|
@ -419,8 +419,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -430,17 +430,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
@ -600,8 +609,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
def get_video_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -609,17 +618,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
||||
The tensors corresponding to the input video.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_videos, video_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
batch_size, frames, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
|
||||
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
|
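The video helper simply folds the frame dimension into the batch before the vision tower and lets callers regroup per video afterwards; a toy sketch with random tensors (all shapes are assumptions):

import torch

pixel_values_videos = torch.randn(2, 8, 3, 336, 336)  # 2 videos, 8 frames each

batch_size, frames, channels, height, width = pixel_values_videos.shape
# Fold frames into the batch dimension, as get_video_features does before the vision tower.
flat = pixel_values_videos.reshape(batch_size * frames, channels, height, width)
print(flat.shape)  # torch.Size([16, 3, 336, 336])

# After encoding, per-frame features can be regrouped per video.
fake_tower_output = torch.randn(batch_size * frames, 577, 1024)  # placeholder features
per_video = fake_tower_output.view(batch_size, frames, *fake_tower_output.shape[1:])
print(per_video.shape)  # torch.Size([2, 8, 577, 1024])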
@ -252,8 +252,8 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -263,17 +263,26 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
@ -311,8 +320,8 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
def get_video_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -320,17 +329,26 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
||||
The tensors corresponding to the input video.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_videos, video_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
batch_size, frames, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
|
||||
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
|
@ -419,8 +419,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -430,17 +430,26 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each images (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
|
@ -255,8 +255,8 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -265,15 +265,19 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_sizes (`torch.Tensor`):
|
||||
image_sizes (`torch.Tensor`, *optional*):
|
||||
Tensor containing the image sizes as returned by the processor.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||
|
@ -135,8 +135,8 @@ class Mistral3Model(LlavaModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -145,15 +145,19 @@ class Mistral3Model(LlavaModel):
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
image_sizes (`torch.Tensor`):
|
||||
image_sizes (`torch.Tensor`, *optional*):
|
||||
Tensor containing the image sizes as returned by the processor.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||
|
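A hedged usage sketch for the reworked Mistral3 signature. Here `model` and `processor` are assumed to be an already loaded `Mistral3Model` and its matching processor (no specific checkpoint is implied by the diff), and the processor is assumed to return `image_sizes` as described in the docstring above:

import torch
from PIL import Image

image = Image.new("RGB", (512, 512))  # placeholder image
inputs = processor(text="Describe the image.", images=image, return_tensors="pt")

with torch.no_grad():
    # vision_feature_layer is now optional and falls back to model.config.vision_feature_layer.
    image_features = model.get_image_features(
        pixel_values=inputs["pixel_values"],
        image_sizes=inputs["image_sizes"],
    )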
@ -2187,6 +2187,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_video_features(
|
||||
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||
):
|
||||
"""
|
||||
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input videos.
|
||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each video in LLM.
|
||||
"""
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
return video_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each image in LLM.
|
||||
"""
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
return image_embeds
|
||||
|
||||
def get_audio_features(
|
||||
self,
|
||||
input_features: torch.FloatTensor,
|
||||
feature_attention_mask: Optional[torch.LongTensor] = None,
|
||||
audio_feature_lengths: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Encodes audios into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
input_features (`torch.FloatTensor`):
|
||||
The tensors corresponding to the input audios.
|
||||
feature_attention_mask (`torch.LongTensor`, *optional*):
|
||||
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
|
||||
The length of feature shape of each audio in LLM.
|
||||
"""
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||
else:
|
||||
audio_feature_lengths = None
|
||||
|
||||
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
)
|
||||
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
audio_outputs = self.audio_tower(
|
||||
input_features,
|
||||
feature_lens=feature_lens,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
|
||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||
|
||||
return audio_features
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
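The audio branch first derives each clip's valid length from the padding mask and flattens the padded features before the audio tower; a standalone sketch of that masking step with dummy tensors (shapes are assumptions):

import torch

input_features = torch.randn(2, 128, 50)  # 2 clips, 128 mel bins, 50 frames
feature_attention_mask = torch.zeros(2, 50, dtype=torch.long)
feature_attention_mask[0, :50] = 1  # first clip fully valid
feature_attention_mask[1, :30] = 1  # second clip padded after 30 frames

# Per-clip valid lengths, as computed inside get_audio_features.
audio_feature_lengths = feature_attention_mask.sum(dim=1)
print(audio_feature_lengths)  # tensor([50, 30])

# Drop padded frames and pack the rest into (mel_bins, total_valid_frames).
packed = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
print(packed.shape)  # torch.Size([128, 80])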
@ -2284,11 +2353,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text , audios , image and video
|
||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||
if input_features is not None:
|
||||
audio_features = self.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
)
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||
else:
|
||||
audio_feature_lengths = None
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
if (
|
||||
cache_position is None
|
||||
@ -2315,63 +2430,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text , audios , image and video
|
||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||
if input_features is not None:
|
||||
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
)
|
||||
feature_lens = (
|
||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
)
|
||||
audio_outputs = self.audio_tower(
|
||||
input_features,
|
||||
feature_lens=feature_lens,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
outputs = self.model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
|
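The merge step is identical for audio, image and video tokens: build a boolean mask over the placeholder token ids and scatter the encoder outputs into those positions. A self-contained sketch of the mechanism (token id and sizes are made up for illustration):

import torch

hidden = 4
image_token_id = 99  # assumed placeholder id

input_ids = torch.tensor([[7, 99, 99, 12]])   # text, IMG, IMG, text
inputs_embeds = torch.zeros(1, 4, hidden)      # stand-in for the text embeddings
image_embeds = torch.ones(2, hidden)           # one row per image placeholder token

image_mask = (
    (input_ids == image_token_id)
    .unsqueeze(-1)
    .expand_as(inputs_embeds)
    .to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.]) -> placeholders now carry image features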
@ -2181,6 +2181,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_video_features(
|
||||
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||
):
|
||||
"""
|
||||
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input videos.
|
||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each video in LLM.
|
||||
"""
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
return video_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each image in LLM.
|
||||
"""
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
return image_embeds
|
||||
|
||||
def get_audio_features(
|
||||
self,
|
||||
input_features: torch.FloatTensor,
|
||||
feature_attention_mask: Optional[torch.LongTensor] = None,
|
||||
audio_feature_lengths: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Encodes audios into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
input_features (`torch.FloatTensor`):
|
||||
The tensors corresponding to the input audios.
|
||||
feature_attention_mask (`torch.LongTensor`, *optional*):
|
||||
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
|
||||
The length of feature shape of each audio in LLM.
|
||||
"""
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||
else:
|
||||
audio_feature_lengths = None
|
||||
|
||||
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
)
|
||||
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
audio_outputs = self.audio_tower(
|
||||
input_features,
|
||||
feature_lens=feature_lens,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
|
||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||
|
||||
return audio_features
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2278,11 +2347,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text , audios , image and video
|
||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||
if input_features is not None:
|
||||
audio_features = self.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask=feature_attention_mask,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
)
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||
else:
|
||||
audio_feature_lengths = None
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
if (
|
||||
cache_position is None
|
||||
@ -2309,63 +2424,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text , audios , image and video
|
||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||
if input_features is not None:
|
||||
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
)
|
||||
feature_lens = (
|
||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||
)
|
||||
audio_outputs = self.audio_tower(
|
||||
input_features,
|
||||
feature_lens=feature_lens,
|
||||
aftercnn_lens=audio_feat_lengths,
|
||||
)
|
||||
audio_features = audio_outputs.last_hidden_state
|
||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||
audio_mask = (
|
||||
(input_ids == self.config.audio_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_mask = (
|
||||
(input_ids == self.config.image_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
video_mask = (
|
||||
(input_ids == self.config.video_token_id)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
outputs = self.model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
|
@ -1583,6 +1583,36 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
def get_video_features(
|
||||
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||
):
|
||||
"""
|
||||
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input videos.
|
||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each video in LLM.
|
||||
"""
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
return video_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each image in LLM.
|
||||
"""
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
return image_embeds
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1627,8 +1657,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
@ -1645,8 +1674,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
|
@ -646,8 +646,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
@ -664,8 +663,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
|
@ -1508,6 +1508,36 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
def get_video_features(
|
||||
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||
):
|
||||
"""
|
||||
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input videos.
|
||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each video in LLM.
|
||||
"""
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
return video_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each image in LLM.
|
||||
"""
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
return image_embeds
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1549,9 +1579,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
@ -1567,9 +1596,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
|
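A hedged usage sketch of the new Qwen2-VL helpers. The checkpoint id is an assumption, and the processor is expected to return `image_grid_thw` alongside `pixel_values`, matching the forward pass above:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLModel

ckpt = "Qwen/Qwen2-VL-2B-Instruct"  # assumed checkpoint
model = Qwen2VLModel.from_pretrained(ckpt)
processor = AutoProcessor.from_pretrained(ckpt)

image = Image.new("RGB", (448, 448))  # placeholder image
inputs = processor(text=["Describe this image."], images=[image], return_tensors="pt")

with torch.no_grad():
    image_embeds = model.get_image_features(inputs["pixel_values"], inputs["image_grid_thw"])
print(image_embeds.shape)  # (number of merged image tokens, hidden size)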
@ -632,6 +632,52 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
|
||||
return merged_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
pixel_attention_mask (`torch.LongTensor`, *optional*):
|
||||
The attention mask indicating padded regions in the image.
|
||||
"""
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
|
||||
|
||||
if not any(real_images_inds):
|
||||
# no images, leave one empty image.
|
||||
real_images_inds[0] = True
|
||||
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=[pixel_values.shape[i] for i in (0, 2, 3)],
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
|
||||
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
|
||||
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
|
||||
image_hidden_states = image_hidden_states.last_hidden_state
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -704,48 +750,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
|
||||
if pixel_values is not None and image_hidden_states is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
|
||||
elif pixel_values is not None:
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values
|
||||
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
|
||||
|
||||
if not any(real_images_inds):
|
||||
# no images, leave one empty image.
|
||||
real_images_inds[0] = True
|
||||
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=[pixel_values.shape[i] for i in (0, 2, 3)],
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
batch_size * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
|
||||
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
).last_hidden_state
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(image_hidden_states)
|
||||
|
||||
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
|
@ -22,7 +22,7 @@ from torch import nn
|
||||
from ...cache_utils import DynamicCache
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
|
||||
from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
|
||||
from ..idefics3.modeling_idefics3 import (
|
||||
@ -195,6 +195,64 @@ class SmolVLMModel(Idefics3Model):
|
||||
merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
|
||||
return merged_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
|
||||
"""
|
||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The tensors corresponding to the input images.
|
||||
pixel_attention_mask (`torch.LongTensor`, *optional*):
|
||||
The attention mask indicating padded regions in the image.
|
||||
"""
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
|
||||
|
||||
if not any(real_images_inds):
|
||||
# no images, leave one empty image.
|
||||
real_images_inds[0] = True
|
||||
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=[pixel_values.shape[i] for i in (0, 2, 3)],
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
|
||||
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
|
||||
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
|
||||
image_hidden_states = image_hidden_states.last_hidden_state
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
|
||||
the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
|
||||
max_num_images is the maximum number of images among the batch_size samples in the batch.
|
||||
Padding images are not needed beyond padding the pixel_values at the entrance of the model.
|
||||
For efficiency, we only pass through the vision_model's forward the real images by
|
||||
discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
|
||||
image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
|
||||
"""
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -249,48 +307,7 @@ class SmolVLMModel(Idefics3Model):
|
||||
if pixel_values is not None and image_hidden_states is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
|
||||
elif pixel_values is not None:
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values
|
||||
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
|
||||
|
||||
if not any(real_images_inds):
|
||||
# no images, leave one empty image.
|
||||
real_images_inds[0] = True
|
||||
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=[pixel_values.shape[i] for i in (0, 2, 3)],
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
batch_size * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
|
||||
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
).last_hidden_state
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(image_hidden_states)
|
||||
|
||||
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
|
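The SmolVLM helper's padding handling can be exercised on its own: all-zero padding images are dropped and the pixel-level mask is reduced to one boolean per patch. A standalone sketch with dummy tensors (image size and patch size are assumptions):

import torch

pixel_values = torch.zeros(1, 3, 3, 32, 32)   # 1 sample padded to 3 image slots
pixel_values[0, 0] = torch.rand(3, 32, 32)    # only the first slot holds a real image
patch_size = 16                                # assumed vision patch size

batch_size, num_images = pixel_values.shape[:2]
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

# Padding images are all zeros; keep only the real ones.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
print(pixel_values.shape)  # torch.Size([1, 3, 32, 32])

# Pixel-level mask -> one boolean per patch, as fed to the vision encoder.
pixel_attention_mask = torch.ones(pixel_values.shape[0], 32, 32, dtype=torch.bool)
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask.shape)  # torch.Size([1, 2, 2])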
@ -205,8 +205,8 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values_images: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -214,16 +214,25 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
Args:
|
||||
pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
|
||||
@ -249,7 +258,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
def get_video_features(
|
||||
self,
|
||||
pixel_values_videos: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
):
|
||||
"""
|
||||
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
||||
@ -257,7 +266,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
Args:
|
||||
pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
||||
The tensors corresponding to the input videos.
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
@ -265,6 +274,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
|
||||
video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`).
|
||||
frames (`int`): Number of frames the videos have.
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
|
||||
|
||||
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
|
||||
|