Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-06 14:20:04 +06:00)
[VLMs] add helpers to get multimodal encodings (#37743)
* add helpers in VLMs
* fix tests and copies
* fix blip tests
* make fix-copies
* fix copies
* fixup
This commit is contained in:
parent 1e921a3a9c
commit 3ab47b6ce3
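The hunks below add public `get_image_features` / `get_video_features` helpers to several VLM classes so the multimodal encoder path can be called on its own, without running the full language-model forward pass. A minimal usage sketch for the BLIP-2 case, assuming a standard checkpoint and processor (the checkpoint name and the dummy image are illustrative, not part of this diff):

```python
# Sketch only: checkpoint name and dummy input are assumptions for illustration.
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.new("RGB", (224, 224), color="white")  # placeholder image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    # Returns the projected Q-Former queries that are fed to the language model.
    image_features = model.get_image_features(inputs.pixel_values)

print(image_features.shape)  # (batch_size, num_query_tokens, hidden_size)
```

With `return_dict=True` the helper also hands back the raw vision-encoder and Q-Former outputs, which is how the refactored `forward` below consumes it.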
@@ -1197,7 +1197,7 @@ class AriaModel(AriaPreTrainedModel):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
-        pixel_mask: torch.FloatTensor = None,
+        pixel_mask: Optional[torch.FloatTensor] = None,
        vision_feature_layer: int = -1,
    ):
        """
@@ -1208,13 +1208,16 @@ class AriaModel(AriaPreTrainedModel):
                The tensors corresponding to the input images.
            pixel_mask (`torch.FloatTensor]`, *optional*):
                The tensors corresponding to the input image mask.
-            vision_feature_layer (`Union[int, List[int]]`):
+            vision_feature_layer (`Union[int, List[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
        image_outputs = self.vision_tower(
            pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
@@ -1325,7 +1325,7 @@ class AriaModel(LlavaModel):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
-        pixel_mask: torch.FloatTensor = None,
+        pixel_mask: Optional[torch.FloatTensor] = None,
        vision_feature_layer: int = -1,
    ):
        """
@@ -1336,13 +1336,16 @@ class AriaModel(LlavaModel):
                The tensors corresponding to the input images.
            pixel_mask (`torch.FloatTensor]`, *optional*):
                The tensors corresponding to the input image mask.
-            vision_feature_layer (`Union[int, List[int]]`):
+            vision_feature_layer (`Union[int, List[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
        image_outputs = self.vision_tower(
            pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
@@ -213,8 +213,8 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
-        vision_feature_layer: Union[int, List[int]],
-        vision_feature_select_strategy: str,
+        vision_feature_layer: Optional[Union[int, List[int]]] = None,
+        vision_feature_select_strategy: Optional[str] = None,
        **kwargs,
    ):
        """
@@ -223,16 +223,25 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
-            vision_feature_layer (`Union[int, List[int]]`):
+            vision_feature_layer (`Union[int, List[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
-            vision_feature_select_strategy (`str`):
+            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

@@ -1956,6 +1956,50 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = False,
+    ):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        """
+        # step 1: forward the images through the vision encoder,
+        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=True,
+        )
+        image_embeds = vision_outputs[0]
+
+        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs[0]
+
+        # Qformer is kept in fp32, we downcast the output back if needed
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        # step 3: use the language model, conditioned on the query outputs and the prompt
+        language_model_inputs = self.language_projection(query_output)
+        if return_dict:
+            return language_model_inputs, vision_outputs, query_outputs
+        return language_model_inputs
+
    @auto_docstring
    def forward(
        self,
@@ -2047,37 +2091,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # step 1: forward the images through the vision encoder,
-        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
-        vision_outputs = self.vision_model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            interpolate_pos_encoding=interpolate_pos_encoding,
+        language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
+            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
        )
-        image_embeds = vision_outputs[0]
-
-        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
-        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
-
-        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-        query_outputs = self.qformer(
-            query_embeds=query_tokens,
-            encoder_hidden_states=image_embeds,
-            encoder_attention_mask=image_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        query_output = query_outputs[0]
-
-        # Qformer is kept in fp32, we downcast the output back if needed
-        if query_output.dtype != image_embeds.dtype:
-            query_output = query_output.to(image_embeds.dtype)
-
-        # step 3: use the language model, conditioned on the query outputs and the prompt
-        language_model_inputs = self.language_projection(query_output)
+        vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
+        query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
+
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
@@ -904,6 +904,12 @@ class ChameleonModel(ChameleonPreTrainedModel):
        self.embed_tokens = value

    def get_image_tokens(self, pixel_values: torch.FloatTensor):
+        logger.warning(
+            "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
+        )
+        return self.get_image_features(pixel_values)
+
+    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -957,7 +963,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
        )

        if pixel_values is not None:
-            image_tokens = self.get_image_tokens(pixel_values)
+            image_tokens = self.get_image_features(pixel_values)
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
            if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
                n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
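For Chameleon (and Emu3 below), `get_image_tokens` is kept as a deprecated alias that now warns and delegates to the new `get_image_features`. A migration sketch, assuming a Chameleon checkpoint; the random pixel tensor is a generic placeholder, not taken from this diff:

```python
# Sketch only: illustrates the deprecation path added above.
import torch
from transformers import ChameleonModel

model = ChameleonModel.from_pretrained("facebook/chameleon-7b")  # assumed checkpoint
pixel_values = torch.randn(1, 3, 512, 512, dtype=model.dtype)    # placeholder input

with torch.no_grad():
    tokens_new = model.get_image_features(pixel_values)  # new public helper
    tokens_old = model.get_image_tokens(pixel_values)    # deprecated alias: warns, then delegates
assert torch.equal(tokens_new, tokens_old)
```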
@@ -1586,6 +1586,12 @@ class Emu3Model(Emu3PreTrainedModel):
        self.text_model.set_input_embeddings(value)

    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
+        logger.warning(
+            "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
+        )
+        return self.get_image_features(pixel_values, image_sizes)
+
+    def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1662,7 +1668,7 @@ class Emu3Model(Emu3PreTrainedModel):
        )

        if pixel_values is not None:
-            image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+            image_tokens = self.get_image_features(pixel_values, image_sizes)
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
@@ -941,6 +941,12 @@ class Emu3Model(Emu3PreTrainedModel):
        self.text_model.set_input_embeddings(value)

    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
+        logger.warning(
+            "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
+        )
+        return self.get_image_features(pixel_values, image_sizes)
+
+    def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1017,7 +1023,7 @@ class Emu3Model(Emu3PreTrainedModel):
        )

        if pixel_values is not None:
-            image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+            image_tokens = self.get_image_features(pixel_values, image_sizes)
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
@@ -125,9 +125,25 @@ class FuyuModel(FuyuPreTrainedModel):
                    f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
                    f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
                )
-            output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
+            output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
+                output_embeddings.device
+            )
        return output_embeddings

+    def get_image_features(self, pixel_values: torch.FloatTensor):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        """
+        patch_embeddings = [
+            self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
+            for patch in pixel_values
+        ]
+        return patch_embeddings
+
    @auto_docstring
    def forward(
        self,
@@ -185,12 +201,7 @@ class FuyuModel(FuyuPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if image_patches is not None and past_key_values is None:
-            patch_embeddings = [
-                self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
-                .squeeze(0)
-                .to(inputs_embeds.device)
-                for patch in image_patches
-            ]
+            patch_embeddings = self.get_image_features(image_patches)
            inputs_embeds = self.gather_continuous_embeddings(
                word_embeddings=inputs_embeds,
                continuous_embeddings=patch_embeddings,
@@ -961,13 +961,57 @@ class Idefics2Model(Idefics2PreTrainedModel):
            - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
            - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
        """
-        num_images, _, vision_hidden_size = image_hidden_states.shape
        special_image_token_mask = input_ids == self.image_token_id
        new_inputs_embeds = inputs_embeds.clone()
-        reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
-        new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
+        new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device)
        return new_inputs_embeds

+    def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            pixel_attention_mask (`torch.LongTensor`, *optional*):
+                The attention mask indicating padded regions in the image.
+        """
+        batch_size, num_images, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+        # Remove padding images - padding images are full 0.
+        nb_values_per_image = pixel_values.shape[1:].numel()
+        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+        pixel_values = pixel_values[real_images_inds].contiguous()
+
+        # Handle the vision attention mask
+        if pixel_attention_mask is None:
+            pixel_attention_mask = torch.ones(
+                size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+        else:
+            # Remove padding images from the mask
+            pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+        patch_size = self.config.vision_config.patch_size
+        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
+        # Get sequence from the vision encoder
+        image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+        image_hidden_states = image_hidden_states.last_hidden_state
+
+        # Modality projection & resampling
+        image_hidden_states = self.connector(
+            image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
+        )
+        image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
+        return image_hidden_states
+
    @can_return_tuple
    @auto_docstring(
        custom_intro="""
@@ -1052,45 +1096,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
-            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
-
-            # Remove padding images - padding images are full 0.
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
-            pixel_values = pixel_values[real_images_inds].contiguous()
-
-            # Handle the vision attention mask
-            if pixel_attention_mask is None:
-                pixel_attention_mask = torch.ones(
-                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
-                    dtype=torch.bool,
-                    device=pixel_values.device,
-                )
-            else:
-                # Remove padding images from the mask
-                pixel_attention_mask = pixel_attention_mask.view(
-                    batch_size * num_images, *pixel_attention_mask.shape[2:]
-                )
-                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
-
-            patch_size = self.config.vision_config.patch_size
-            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
-            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
-            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
-
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(
-                pixel_values=pixel_values,
-                patch_attention_mask=patch_attention_mask,
-            ).last_hidden_state
-
-            # Modality projection & resampling
-            image_hidden_states = self.connector(
-                image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
-            )
+            image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)

        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
@@ -692,16 +692,59 @@ class Idefics3Model(Idefics3PreTrainedModel):
            - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
            - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
        """
-        num_images, _, vision_hidden_size = image_hidden_states.shape
        special_image_token_mask = input_ids == self.image_token_id
        # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
        new_inputs_embeds = inputs_embeds.clone()
-        reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
        # cast to the dtype of the input_embeds to support quantized models
-        reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
-        new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
+        image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
+        new_inputs_embeds[special_image_token_mask] = image_hidden_states
        return new_inputs_embeds

+    def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            pixel_attention_mask (`torch.LongTensor`, *optional*):
+                The attention mask indicating padded regions in the image.
+        """
+        batch_size, num_images, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
+        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+        # Remove padding images - padding images are full 0.
+        nb_values_per_image = pixel_values.shape[1:].numel()
+        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+        pixel_values = pixel_values[real_images_inds].contiguous()
+
+        # Handle the vision attention mask
+        if pixel_attention_mask is None:
+            pixel_attention_mask = torch.ones(
+                size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+        else:
+            # Remove padding images from the mask
+            pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+        patch_size = self.config.vision_config.patch_size
+        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+        # Get sequence from the vision encoder
+        image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+        image_hidden_states.last_hidden_state
+
+        # Modality projection & resampling
+        image_hidden_states = self.connector(image_hidden_states.last_hidden_state)
+        image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
+        return image_hidden_states
+
    @can_return_tuple
    @auto_docstring(
        custom_intro="""
@@ -774,43 +817,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
-            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
-
-            # Remove padding images - padding images are full 0.
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
-            pixel_values = pixel_values[real_images_inds].contiguous()
-
-            # Handle the vision attention mask
-            if pixel_attention_mask is None:
-                pixel_attention_mask = torch.ones(
-                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
-                    dtype=torch.bool,
-                    device=pixel_values.device,
-                )
-            else:
-                # Remove padding images from the mask
-                pixel_attention_mask = pixel_attention_mask.view(
-                    batch_size * num_images, *pixel_attention_mask.shape[2:]
-                )
-                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
-
-            patch_size = self.config.vision_config.patch_size
-            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
-            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
-            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(
-                pixel_values=pixel_values,
-                patch_attention_mask=patch_attention_mask,
-            ).last_hidden_state
-
-            # Modality projection & resampling
-            image_hidden_states = self.connector(image_hidden_states)
+            image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)

        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
@@ -1466,6 +1466,55 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        qformer_input_ids: torch.LongTensor,
+        qformer_attention_mask: Optional[torch.LongTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = False,
+    ):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        """
+        # step 1: forward the images through the vision encoder,
+        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=True,
+        )
+        image_embeds = vision_outputs[0]
+
+        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+        if qformer_attention_mask is None:
+            qformer_attention_mask = torch.ones_like(qformer_input_ids)
+        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+        query_outputs = self.qformer(
+            input_ids=qformer_input_ids,
+            attention_mask=qformer_attention_mask,
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+        # step 3: use the language model, conditioned on the query outputs and the prompt
+        language_model_inputs = self.language_projection(query_output)
+        if return_dict:
+            return language_model_inputs, vision_outputs, query_outputs
+        return language_model_inputs
+
    @can_return_tuple
    @auto_docstring
    def forward(
@@ -1555,40 +1604,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # step 1: forward the images through the vision encoder,
-        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
-        vision_outputs = self.vision_model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+        language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
+            pixel_values,
+            qformer_input_ids=qformer_input_ids,
+            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=True,
        )
-        image_embeds = vision_outputs[0]
-
-        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
-        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
-
-        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
-        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
-        if qformer_attention_mask is None:
-            qformer_attention_mask = torch.ones_like(qformer_input_ids)
-        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
-        query_outputs = self.qformer(
-            input_ids=qformer_input_ids,
-            attention_mask=qformer_attention_mask,
-            query_embeds=query_tokens,
-            encoder_hidden_states=image_embeds,
-            encoder_attention_mask=image_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        query_output = query_outputs[0][:, : query_tokens.size(1), :]
-
-        # step 3: use the language model, conditioned on the query outputs and the prompt
-        language_model_inputs = self.language_projection(query_output)
+        vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
+        query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
+
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
@@ -1690,30 +1714,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
        self._preprocess_accelerate()

        batch_size = pixel_values.shape[0]
-        image_embeds = self.vision_model(
+        language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
            pixel_values,
-            return_dict=True,
+            qformer_input_ids=qformer_input_ids,
+            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
-        ).last_hidden_state
-
-        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
-
-        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
-        if qformer_attention_mask is None:
-            qformer_attention_mask = torch.ones_like(qformer_input_ids)
-        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
-        query_outputs = self.qformer(
-            input_ids=qformer_input_ids,
-            attention_mask=qformer_attention_mask,
-            query_embeds=query_tokens,
-            encoder_hidden_states=image_embeds,
-            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
-        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
-
-        language_model_inputs = self.language_projection(query_output)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
@@ -1722,7 +1729,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
            start_tokens = [self.config.text_config.bos_token_id]
            if getattr(self.config, "image_token_id", None) is not None:
                start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
-            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
            input_ids = input_ids.repeat(batch_size, 1)

        if attention_mask is None:
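Unlike BLIP-2, the InstructBLIP helper also feeds the tokenized instruction to the Q-Former, so `qformer_input_ids` is part of its signature. A rough sketch; the checkpoint, prompt, and image are illustrative placeholders:

```python
# Sketch only: checkpoint, prompt and image are assumptions for illustration.
import torch
from PIL import Image
from transformers import AutoProcessor, InstructBlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")

inputs = processor(images=Image.new("RGB", (224, 224)), text="Describe the image.", return_tensors="pt")
with torch.no_grad():
    # return_dict=True also returns the raw vision-encoder and Q-Former outputs.
    features, vision_outputs, query_outputs = model.get_image_features(
        inputs.pixel_values,
        qformer_input_ids=inputs.qformer_input_ids,
        return_dict=True,
    )
```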
@@ -1470,6 +1470,23 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        qformer_input_ids: torch.LongTensor,
+        qformer_attention_mask: Optional[torch.LongTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = False,
+    ):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        """
+        pass
+
    @can_return_tuple
    @auto_docstring
    def forward(
@@ -1582,50 +1599,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # step 1: forward the images through the vision encoder,
-        # we process in a batched way, later unbatch it back (video has frames=4 always)
-        batch_size, frames, channel, height, width = pixel_values.shape
-        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
-
-        vision_outputs = self.vision_model(
-            pixel_values=pixel_values,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+        language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
+            pixel_values,
+            qformer_input_ids=qformer_input_ids,
+            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=True,
        )
-        image_embeds = vision_outputs[0]
-
-        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
-        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
-
-        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
-        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
-
-        if qformer_attention_mask is None:
-            qformer_attention_mask = torch.ones_like(qformer_input_ids)
-
-        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
-        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
-        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
-        query_outputs = self.qformer(
-            input_ids=qformer_input_ids,
-            attention_mask=qformer_attention_mask,
-            query_embeds=query_tokens,
-            encoder_hidden_states=image_embeds,
-            encoder_attention_mask=image_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        query_output = query_outputs[0][:, : query_tokens.size(1), :]
-
-        # step 3: use the language model, conditioned on the query outputs and the prompt
-        language_model_inputs = self.language_projection(query_output)
-
-        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
-        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+        vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
+        query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
+
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
@@ -1726,39 +1708,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
        # preprocess for `accelerate`
        self._preprocess_accelerate()

-        # we process in a batched way, later unbatch it back (video has frames=4)
-        batch_size, frames, channel, height, width = pixel_values.shape
-        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
-
-        image_embeds = self.vision_model(
+        batch_size = pixel_values.shape[0]
+        language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
            pixel_values,
-            return_dict=True,
+            qformer_input_ids=qformer_input_ids,
+            qformer_attention_mask=qformer_attention_mask,
            interpolate_pos_encoding=interpolate_pos_encoding,
-        ).last_hidden_state
-        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
-
-        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
-        if qformer_attention_mask is None:
-            qformer_attention_mask = torch.ones_like(qformer_input_ids)
-
-        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
-        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
-        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
-        query_outputs = self.qformer(
-            input_ids=qformer_input_ids,
-            attention_mask=qformer_attention_mask,
-            query_embeds=query_tokens,
-            encoder_hidden_states=image_embeds,
-            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
-        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
-
-        language_model_inputs = self.language_projection(query_output)
-
-        # unbatch the embeddings back by moving frames to seq-len
-        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
@@ -1767,7 +1725,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
            start_tokens = [self.config.text_config.bos_token_id]
            if getattr(self.config, "video_token_id", None) is not None:
                start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
-            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
            input_ids = input_ids.repeat(batch_size, 1)

        if attention_mask is None:
@@ -1807,6 +1765,65 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel

        return outputs

+    def get_video_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        qformer_input_ids: torch.LongTensor,
+        qformer_attention_mask: Optional[torch.LongTensor] = None,
+        interpolate_pos_encoding: Optional[bool] = False,
+        return_dict: Optional[bool] = False,
+    ):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        """
+        # step 1: forward the images through the vision encoder,
+        # we process in a batched way, later unbatch it back (video has frames=4 always)
+        batch_size, frames, channel, height, width = pixel_values.shape
+        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
+
+        vision_outputs = self.vision_model(
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=True,
+        )
+        image_embeds = vision_outputs[0]
+
+        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        if qformer_attention_mask is None:
+            qformer_attention_mask = torch.ones_like(qformer_input_ids)
+
+        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
+        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
+        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
+        query_outputs = self.qformer(
+            input_ids=qformer_input_ids,
+            attention_mask=qformer_attention_mask,
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs[0][:, : query_tokens.size(1), :]
+
+        # step 3: use the language model, conditioned on the query outputs and the prompt
+        language_model_inputs = self.language_projection(query_output)
+
+        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
+        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
+        if return_dict:
+            return language_model_inputs, vision_outputs, query_outputs
+        return language_model_inputs
+

__all__ = [
    "InstructBlipVideoVisionModel",
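The video variant mirrors this with `get_video_features`, which flattens the frame axis before the vision encoder and then folds the per-frame query tokens back into one sequence per video. A sketch with a random 4-frame clip; the checkpoint path, frame count, and image resolution are assumptions, not taken from this diff:

```python
# Sketch only: random frames stand in for a real processed video clip.
import torch
from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoProcessor

ckpt = "path/to/instructblipvideo-checkpoint"  # placeholder: fill in a real checkpoint
model = InstructBlipVideoForConditionalGeneration.from_pretrained(ckpt)
processor = InstructBlipVideoProcessor.from_pretrained(ckpt)

qformer_ids = processor.qformer_tokenizer("What is happening?", return_tensors="pt").input_ids
pixel_values = torch.randn(1, 4, 3, 224, 224)  # (batch, frames=4, channels, height, width)

with torch.no_grad():
    video_features = model.get_video_features(pixel_values, qformer_input_ids=qformer_ids)
print(video_features.shape)  # (batch, num_query_tokens * frames, hidden_size)
```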
@ -295,6 +295,76 @@ class InstructBlipVideoModel(InstructBlipModel):
|
|||||||
|
|
||||||
|
|
||||||
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
|
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
|
||||||
|
def get_video_features(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
qformer_input_ids: torch.LongTensor,
|
||||||
|
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
|
return_dict: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
"""
|
||||||
|
# step 1: forward the images through the vision encoder,
|
||||||
|
# we process in a batched way, later unbatch it back (video has frames=4 always)
|
||||||
|
batch_size, frames, channel, height, width = pixel_values.shape
|
||||||
|
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
|
||||||
|
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
image_embeds = vision_outputs[0]
|
||||||
|
|
||||||
|
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
||||||
|
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||||
|
|
||||||
|
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
||||||
|
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||||
|
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||||
|
|
||||||
|
if qformer_attention_mask is None:
|
||||||
|
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
||||||
|
|
||||||
|
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
|
||||||
|
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
|
||||||
|
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
||||||
|
query_outputs = self.qformer(
|
||||||
|
input_ids=qformer_input_ids,
|
||||||
|
attention_mask=qformer_attention_mask,
|
||||||
|
query_embeds=query_tokens,
|
||||||
|
encoder_hidden_states=image_embeds,
|
||||||
|
encoder_attention_mask=image_attention_mask,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
query_output = query_outputs[0][:, : query_tokens.size(1), :]
|
||||||
|
|
||||||
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||||
|
language_model_inputs = self.language_projection(query_output)
|
||||||
|
|
||||||
|
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
|
||||||
|
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
||||||
|
if return_dict:
|
||||||
|
return language_model_inputs, vision_outputs, query_outputs
|
||||||
|
return language_model_inputs
|
||||||
|
|
||||||
|
# Model supports only videos
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
qformer_input_ids: torch.LongTensor,
|
||||||
|
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
|
return_dict: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
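The Q-Former mask construction in the hunk above can be reproduced in isolation. A small sketch with made-up token ids, no real tokenizer or Q-Former involved:

import torch

batch_size, frames, num_query_tokens, prompt_len = 2, 4, 32, 7

qformer_input_ids = torch.randint(0, 1000, (batch_size, prompt_len))
qformer_attention_mask = torch.ones_like(qformer_input_ids)

# each video frame gets its own copy of the instruction prompt
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)

# the learned query tokens are always attended to, so their mask is all ones
query_attention_mask = torch.ones(batch_size * frames, num_query_tokens, dtype=torch.long)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)

print(qformer_attention_mask.shape)  # torch.Size([8, 39]) -> queries + prompt, per frame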
@ -370,50 +440,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# step 1: forward the images through the vision encoder,
|
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
|
||||||
# we process in a batched way, later unbatch it back (video has frames=4 always)
|
pixel_values,
|
||||||
batch_size, frames, channel, height, width = pixel_values.shape
|
qformer_input_ids=qformer_input_ids,
|
||||||
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
|
qformer_attention_mask=qformer_attention_mask,
|
||||||
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
return_dict=True,
|
||||||
)
|
)
|
||||||
image_embeds = vision_outputs[0]
|
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
|
||||||
|
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
|
||||||
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
|
||||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
||||||
|
|
||||||
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
|
||||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
|
||||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
||||||
|
|
||||||
if qformer_attention_mask is None:
|
|
||||||
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
|
||||||
|
|
||||||
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
|
|
||||||
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
|
|
||||||
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
|
||||||
query_outputs = self.qformer(
|
|
||||||
input_ids=qformer_input_ids,
|
|
||||||
attention_mask=qformer_attention_mask,
|
|
||||||
query_embeds=query_tokens,
|
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=image_attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
query_output = query_outputs[0][:, : query_tokens.size(1), :]
|
|
||||||
|
|
||||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
|
||||||
language_model_inputs = self.language_projection(query_output)
|
|
||||||
|
|
||||||
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
|
|
||||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
|
||||||
language_model_attention_mask = torch.ones(
|
language_model_attention_mask = torch.ones(
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||||
)
|
)
|
||||||
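Downstream of `get_video_features`, the projected video prefix is combined with the embedded prompt; the all-ones mask built at the end of the hunk above is what later gets concatenated with the text attention mask (the concatenation itself happens in code outside this hunk). A hedged sketch of that combination step, using random tensors instead of model outputs:

import torch

batch_size, prefix_len, text_len, hidden = 2, 128, 10, 4096

language_model_inputs = torch.randn(batch_size, prefix_len, hidden)   # stand-in for get_video_features output
inputs_embeds = torch.randn(batch_size, text_len, hidden)             # stand-in for embedded prompt tokens
attention_mask = torch.ones(batch_size, text_len, dtype=torch.long)

# the multimodal prefix is always attended to
language_model_attention_mask = torch.ones(
    language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)

inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)
attention_mask = torch.cat([language_model_attention_mask, attention_mask], dim=1)
print(inputs_embeds.shape, attention_mask.shape)  # (2, 138, 4096) (2, 138)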
@ -514,39 +549,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
# preprocess for `accelerate`
|
# preprocess for `accelerate`
|
||||||
self._preprocess_accelerate()
|
self._preprocess_accelerate()
|
||||||
|
|
||||||
# we process in a batched way, later unbatch it back (video has frames=4)
|
batch_size = pixel_values.shape[0]
|
||||||
batch_size, frames, channel, height, width = pixel_values.shape
|
language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
|
||||||
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
|
|
||||||
|
|
||||||
image_embeds = self.vision_model(
|
|
||||||
pixel_values,
|
pixel_values,
|
||||||
return_dict=True,
|
qformer_input_ids=qformer_input_ids,
|
||||||
|
qformer_attention_mask=qformer_attention_mask,
|
||||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
).last_hidden_state
|
|
||||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
||||||
|
|
||||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
|
||||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
|
||||||
if qformer_attention_mask is None:
|
|
||||||
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
|
||||||
|
|
||||||
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
|
|
||||||
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
|
|
||||||
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
|
||||||
query_outputs = self.qformer(
|
|
||||||
input_ids=qformer_input_ids,
|
|
||||||
attention_mask=qformer_attention_mask,
|
|
||||||
query_embeds=query_tokens,
|
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=image_attention_mask,
|
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
|
|
||||||
|
|
||||||
language_model_inputs = self.language_projection(query_output)
|
|
||||||
|
|
||||||
# unbatch the embeddings back by moving frames to seq-len
|
|
||||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
|
||||||
language_attention_mask = torch.ones(
|
language_attention_mask = torch.ones(
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||||
)
|
)
|
||||||
@ -555,7 +566,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
start_tokens = [self.config.text_config.bos_token_id]
|
start_tokens = [self.config.text_config.bos_token_id]
|
||||||
if getattr(self.config, "video_token_id", None) is not None:
|
if getattr(self.config, "video_token_id", None) is not None:
|
||||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||||
input_ids = input_ids.repeat(batch_size, 1)
|
input_ids = input_ids.repeat(batch_size, 1)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
|
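The default prompt built in the hunk above places one video placeholder per query token per frame (4 frames) ahead of BOS. A runnable toy version with made-up token ids:

import torch

video_token_id, bos_token_id = 32000, 1   # hypothetical ids, not taken from a real config
num_query_tokens, frames, batch_size = 32, 4, 2

start_tokens = [video_token_id] * num_query_tokens * frames + [bos_token_id]
input_ids = torch.tensor([start_tokens], dtype=torch.long)
input_ids = input_ids.repeat(batch_size, 1)
print(input_ids.shape)  # torch.Size([2, 129])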
@ -1625,6 +1625,37 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
|
|||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.text_model.model.embed_tokens = value
|
self.text_model.model.embed_tokens = value
|
||||||
|
|
||||||
|
def get_image_features(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
return_attentions: Optional[bool] = False,
|
||||||
|
interpolate_pos_encoding: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
return_attentions (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to return `projection_attentions` or not.
|
||||||
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to interpolate positional embeddings or not.
|
||||||
|
"""
|
||||||
|
vision_model_output = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||||
|
)
|
||||||
|
# The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
|
||||||
|
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
|
||||||
|
# normalized features
|
||||||
|
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
|
||||||
|
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
|
||||||
|
|
||||||
|
if return_attentions:
|
||||||
|
return image_embeds, projection_attentions
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
def forward(
|
def forward(
|
||||||
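The Kosmos-2 helper above layer-norms the full sequence of patch states, L2-normalizes it, and projects it into the text space. A stand-alone sketch with stand-in modules (the real `image_to_text_projection` also returns projection attentions, which the plain `Linear` stand-in does not):

import torch
import torch.nn as nn

batch, num_patches, vision_dim, text_dim = 2, 257, 1024, 2048

last_hidden_state = torch.randn(batch, num_patches, vision_dim)   # stand-in for vision tower output
post_layernorm = nn.LayerNorm(vision_dim)                         # stand-in for post_layernorm
projection = nn.Linear(vision_dim, text_dim)                      # stand-in for image_to_text_projection

image_embeds = post_layernorm(last_hidden_state)
image_embeds = nn.functional.normalize(image_embeds, dim=-1)      # normalized features
image_embeds = projection(image_embeds)
print(image_embeds.shape)  # torch.Size([2, 257, 2048])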
@ -1696,19 +1727,9 @@ class Kosmos2Model(Kosmos2PreTrainedModel):
|
|||||||
if image_embeds is None:
|
if image_embeds is None:
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
|
raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
|
||||||
|
image_embeds, projection_attentions = self.get_image_features(
|
||||||
vision_model_output = self.vision_model(
|
pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding
|
||||||
pixel_values=pixel_values,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
)
|
||||||
# The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
|
|
||||||
image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
|
|
||||||
# normalized features
|
|
||||||
image_embeds = nn.functional.normalize(image_embeds, dim=-1)
|
|
||||||
image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
|
|
||||||
|
|
||||||
outputs = self.text_model(
|
outputs = self.text_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
@ -183,8 +183,8 @@ class LlavaModel(LlavaPreTrainedModel):
|
|||||||
def get_image_features(
|
def get_image_features(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -193,16 +193,25 @@ class LlavaModel(LlavaPreTrainedModel):
|
|||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
if vision_feature_select_strategy not in ["default", "full"]:
|
if vision_feature_select_strategy not in ["default", "full"]:
|
||||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||||
|
|
||||||
|
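The Llava-family hunks above all adopt the same fallback: an explicit argument wins, otherwise the value stored on the config is used. A self-contained illustration with a hypothetical config object (not a transformers class):

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class DummyVisionConfig:  # hypothetical stand-in for the model config
    vision_feature_layer: Union[int, List[int]] = -2
    vision_feature_select_strategy: str = "default"


def resolve(config, vision_feature_layer=None, vision_feature_select_strategy=None):
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else config.vision_feature_select_strategy
    )
    return vision_feature_layer, vision_feature_select_strategy


print(resolve(DummyVisionConfig()))              # (-2, 'default')
print(resolve(DummyVisionConfig(), -1, "full"))  # (-1, 'full')

Making both arguments optional is what lets callers such as `forward()` omit them entirely and rely on the checkpoint's config.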
@ -365,8 +365,8 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
image_sizes: torch.Tensor,
|
image_sizes: torch.Tensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||||
@ -376,17 +376,26 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
|||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||||
Actual image size of each images (H, W).
|
Actual image size of each images (H, W).
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
# ! infer image_num_patches from image_sizes
|
# ! infer image_num_patches from image_sizes
|
||||||
image_num_patches = [
|
image_num_patches = [
|
||||||
image_size_to_num_patches(
|
image_size_to_num_patches(
|
||||||
|
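What `vision_feature_layer` and `vision_feature_select_strategy` amount to can be shown with fake hidden states: "default" drops the CLS token, "full" keeps every position, and a list of layers is concatenated along the hidden dimension. A toy sketch, not the actual model code:

import torch

num_layers, batch, seq_len, dim = 5, 2, 1 + 576, 1024
hidden_states = tuple(torch.randn(batch, seq_len, dim) for _ in range(num_layers))

vision_feature_layer = -2                 # a single layer index
selected = hidden_states[vision_feature_layer]

# "default" drops the CLS token, "full" keeps every position
default_features = selected[:, 1:]
full_features = selected

# a list of layers is concatenated along the hidden dimension
multi_layer = torch.cat([hidden_states[i][:, 1:] for i in (-1, -2)], dim=-1)
print(default_features.shape, full_features.shape, multi_layer.shape)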
@ -419,8 +419,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
image_sizes: torch.Tensor,
|
image_sizes: torch.Tensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||||
@ -430,17 +430,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
|||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||||
Actual image size of each images (H, W).
|
Actual image size of each images (H, W).
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
# ! infer image_num_patches from image_sizes
|
# ! infer image_num_patches from image_sizes
|
||||||
image_num_patches = [
|
image_num_patches = [
|
||||||
image_size_to_num_patches(
|
image_size_to_num_patches(
|
||||||
@ -600,8 +609,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
|||||||
def get_video_features(
|
def get_video_features(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
||||||
@ -609,17 +618,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
|||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
||||||
The tensors corresponding to the input video.
|
The tensors corresponding to the input video.
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
|
video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
|
||||||
and are of shape `(num_videos, video_length, embed_dim)`).
|
and are of shape `(num_videos, video_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
batch_size, frames, channels, height, width = pixel_values.shape
|
batch_size, frames, channels, height, width = pixel_values.shape
|
||||||
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
|
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
|
||||||
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||||
|
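As in the image path, the video branch above folds frames into the batch before the vision tower and regroups them per video afterwards. A toy sketch of just that bookkeeping, with fake features standing in for the vision tower output (pooling and projection steps are omitted):

import torch

batch_size, frames, channels, height, width = 2, 8, 3, 336, 336
patches_per_frame, dim = 576, 1024

pixel_values = torch.randn(batch_size, frames, channels, height, width)
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)

# pretend this came out of the vision tower for the selected feature layer
frame_features = torch.randn(batch_size * frames, patches_per_frame, dim)

# regroup: one tensor per video, frames stacked along the sequence axis
video_features = frame_features.reshape(batch_size, frames * patches_per_frame, dim)
videos = list(torch.split(video_features, 1, dim=0))
print(videos[0].shape)  # torch.Size([1, 4608, 1024])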
@ -252,8 +252,8 @@ class LlavaNextVideoModel(LlavaNextModel):
|
|||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
image_sizes: torch.Tensor,
|
image_sizes: torch.Tensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||||
@ -263,17 +263,26 @@ class LlavaNextVideoModel(LlavaNextModel):
|
|||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||||
Actual image size of each images (H, W).
|
Actual image size of each images (H, W).
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
# ! infer image_num_patches from image_sizes
|
# ! infer image_num_patches from image_sizes
|
||||||
image_num_patches = [
|
image_num_patches = [
|
||||||
image_size_to_num_patches(
|
image_size_to_num_patches(
|
||||||
@ -311,8 +320,8 @@ class LlavaNextVideoModel(LlavaNextModel):
|
|||||||
def get_video_features(
|
def get_video_features(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
Obtains video last hidden states from the vision tower and apply multimodal projection.
|
||||||
@ -320,17 +329,26 @@ class LlavaNextVideoModel(LlavaNextModel):
|
|||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
|
||||||
The tensors corresponding to the input video.
|
The tensors corresponding to the input video.
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
|
video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
|
||||||
and are of shape `(num_videos, video_length, embed_dim)`).
|
and are of shape `(num_videos, video_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
batch_size, frames, channels, height, width = pixel_values.shape
|
batch_size, frames, channels, height, width = pixel_values.shape
|
||||||
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
|
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
|
||||||
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||||
|
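Because images can have different patch counts, these helpers return a list of per-image tensors rather than one stacked tensor. A minimal sketch of splitting a packed tensor back with hard-coded per-image lengths (in the model those lengths are derived from `image_sizes`):

import torch

image_num_patches = [5, 9]        # e.g. differently sized images -> different patch grids
patch_len, dim = 576, 1024

packed = torch.randn(sum(image_num_patches), patch_len, dim)
image_features = torch.split(packed, image_num_patches, dim=0)
print([f.shape for f in image_features])
# [torch.Size([5, 576, 1024]), torch.Size([9, 576, 1024])]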
@ -419,8 +419,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
image_sizes: torch.Tensor,
|
image_sizes: torch.Tensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
vision_feature_select_strategy: str,
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||||
@ -430,17 +430,26 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
|||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||||
Actual image size of each images (H, W).
|
Actual image size of each images (H, W).
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
vision_feature_select_strategy (`str`):
|
vision_feature_select_strategy (`str`, *optional*):
|
||||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||||
Can be one of `"default"` or `"full"`
|
Can be one of `"default"` or `"full"`
|
||||||
Returns:
|
Returns:
|
||||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
# ! infer image_num_patches from image_sizes
|
# ! infer image_num_patches from image_sizes
|
||||||
image_num_patches = [
|
image_num_patches = [
|
||||||
image_size_to_num_patches(
|
image_size_to_num_patches(
|
||||||
|
@ -255,8 +255,8 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
|||||||
def get_image_features(
|
def get_image_features(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
|
||||||
image_sizes: torch.Tensor,
|
image_sizes: torch.Tensor,
|
||||||
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -265,15 +265,19 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
|||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
image_sizes (`torch.Tensor`):
|
image_sizes (`torch.Tensor`, *optional*):
|
||||||
Tensor containing the image sizes as returned by the processor.
|
Tensor containing the image sizes as returned by the processor.
|
||||||
Returns:
|
Returns:
|
||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||||
|
@ -135,8 +135,8 @@ class Mistral3Model(LlavaModel):
|
|||||||
def get_image_features(
|
def get_image_features(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
vision_feature_layer: Union[int, List[int]],
|
|
||||||
image_sizes: torch.Tensor,
|
image_sizes: torch.Tensor,
|
||||||
|
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -145,15 +145,19 @@ class Mistral3Model(LlavaModel):
|
|||||||
Args:
|
Args:
|
||||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||||
The tensors corresponding to the input images.
|
The tensors corresponding to the input images.
|
||||||
vision_feature_layer (`Union[int, List[int]]`):
|
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||||
the vision feature of the corresponding indices will be concatenated to form the
|
the vision feature of the corresponding indices will be concatenated to form the
|
||||||
vision features.
|
vision features.
|
||||||
image_sizes (`torch.Tensor`):
|
image_sizes (`torch.Tensor`, *optional*):
|
||||||
Tensor containing the image sizes as returned by the processor.
|
Tensor containing the image sizes as returned by the processor.
|
||||||
Returns:
|
Returns:
|
||||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||||
"""
|
"""
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||||
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
|
||||||
|
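The small `kwargs` guard in the Mistral3 hunks above simply drops `None` entries before they are forwarded to the vision tower. In isolation:

def filter_none_kwargs(**kwargs):
    # keep only the keyword arguments that were actually provided
    return {k: v for k, v in kwargs.items() if v is not None}


print(filter_none_kwargs(image_sizes=[(336, 448)], output_attentions=None))
# {'image_sizes': [(336, 448)]}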
@ -2187,6 +2187,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.model.set_input_embeddings(value)
|
self.model.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def get_video_features(
|
||||||
|
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input videos.
|
||||||
|
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||||
|
The temporal, height and width of feature shape of each video in LLM.
|
||||||
|
"""
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||||
|
return video_embeds
|
||||||
|
|
||||||
|
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||||
|
"""
|
||||||
|
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||||
|
The temporal, height and width of feature shape of each image in LLM.
|
||||||
|
"""
|
||||||
|
pixel_values = pixel_values.type(self.visual.dtype)
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
|
def get_audio_features(
|
||||||
|
self,
|
||||||
|
input_features: torch.FloatTensor,
|
||||||
|
feature_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
audio_feature_lengths: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes audios into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_features (`torch.FloatTensor`):
|
||||||
|
The tensors corresponding to the input audios.
|
||||||
|
feature_attention_mask (`torch.LongTensor`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||||
|
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
|
||||||
|
The length of feature shape of each audio in LLM.
|
||||||
|
"""
|
||||||
|
if feature_attention_mask is not None:
|
||||||
|
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||||
|
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||||
|
else:
|
||||||
|
audio_feature_lengths = None
|
||||||
|
|
||||||
|
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
||||||
|
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||||
|
)
|
||||||
|
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||||
|
audio_outputs = self.audio_tower(
|
||||||
|
input_features,
|
||||||
|
feature_lens=feature_lens,
|
||||||
|
aftercnn_lens=audio_feat_lengths,
|
||||||
|
)
|
||||||
|
audio_features = audio_outputs.last_hidden_state
|
||||||
|
|
||||||
|
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||||
|
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||||
|
|
||||||
|
return audio_features
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
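The mask handling at the top of `get_audio_features` above derives per-sample lengths from the padding mask and packs away padded frames before the audio tower runs. A toy reproduction with random features and no real audio tower:

import torch

batch, num_mel_bins, max_frames = 2, 128, 10
input_features = torch.randn(batch, num_mel_bins, max_frames)
feature_attention_mask = torch.tensor(
    [[1] * 10, [1] * 6 + [0] * 4], dtype=torch.long
)

audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)   # tensor([10, 6])

# (batch, mel, frames) -> (batch, frames, mel), keep real frames only, back to (mel, total_frames)
packed = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
print(audio_feature_lengths.tolist(), packed.shape)  # [10, 6] torch.Size([128, 16])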
@ -2284,11 +2353,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
# 1. Extract the input embeddings
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
# 2. Merge text, audio, image and video
|
||||||
|
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||||
|
if input_features is not None:
|
||||||
|
audio_features = self.get_audio_features(
|
||||||
|
input_features,
|
||||||
|
feature_attention_mask=feature_attention_mask,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
)
|
||||||
|
audio_mask = (
|
||||||
|
(input_ids == self.config.audio_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||||
|
image_mask = (
|
||||||
|
(input_ids == self.config.image_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||||
|
video_mask = (
|
||||||
|
(input_ids == self.config.video_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
if feature_attention_mask is not None:
|
if feature_attention_mask is not None:
|
||||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||||
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
|
||||||
else:
|
else:
|
||||||
audio_feature_lengths = None
|
audio_feature_lengths = None
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
if (
|
if (
|
||||||
cache_position is None
|
cache_position is None
|
||||||
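The merging step above writes each modality's embeddings into `inputs_embeds` at the positions of the matching placeholder token via `masked_scatter`. A toy version with made-up ids and sizes:

import torch

image_token_id, hidden = 151655, 8   # made-up id and hidden size
input_ids = torch.tensor([[11, image_token_id, image_token_id, 42]])
inputs_embeds = torch.zeros(1, 4, hidden)

# one embedding per placeholder token, in order of appearance
image_embeds = torch.arange(1, 2 * hidden + 1, dtype=torch.float).reshape(2, hidden)

image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds.to(inputs_embeds.dtype))
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 9., 0.]) -> positions 1 and 2 were filled

`masked_scatter` fills masked positions in flattened order, which is why the encoder must emit exactly one embedding per placeholder token.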
@ -2315,63 +2430,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
position_ids = position_ids.add(delta)
|
position_ids = position_ids.add(delta)
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
# 1. Extract the input embeddings
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
# 2. Merge text , audios , image and video
|
|
||||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
|
||||||
if input_features is not None:
|
|
||||||
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
|
||||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
|
||||||
)
|
|
||||||
feature_lens = (
|
|
||||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
|
||||||
)
|
|
||||||
audio_outputs = self.audio_tower(
|
|
||||||
input_features,
|
|
||||||
feature_lens=feature_lens,
|
|
||||||
aftercnn_lens=audio_feat_lengths,
|
|
||||||
)
|
|
||||||
audio_features = audio_outputs.last_hidden_state
|
|
||||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
|
||||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
|
||||||
audio_mask = (
|
|
||||||
(input_ids == self.config.audio_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
pixel_values = pixel_values.type(self.visual.dtype)
|
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
||||||
image_mask = (
|
|
||||||
(input_ids == self.config.image_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
|
||||||
|
|
||||||
if pixel_values_videos is not None:
|
|
||||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
||||||
video_mask = (
|
|
||||||
(input_ids == self.config.video_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
@ -2181,6 +2181,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.model.set_input_embeddings(value)
|
self.model.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def get_video_features(
|
||||||
|
self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes videos into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input videos.
|
||||||
|
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||||
|
The temporal, height and width of feature shape of each video in LLM.
|
||||||
|
"""
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||||
|
return video_embeds
|
||||||
|
|
||||||
|
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||||
|
"""
|
||||||
|
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||||
|
The tensors corresponding to the input images.
|
||||||
|
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||||
|
The temporal, height and width of feature shape of each image in LLM.
|
||||||
|
"""
|
||||||
|
pixel_values = pixel_values.type(self.visual.dtype)
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
|
def get_audio_features(
|
||||||
|
self,
|
||||||
|
input_features: torch.FloatTensor,
|
||||||
|
feature_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
audio_feature_lengths: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encodes audios into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_features (`torch.FloatTensor`):
|
||||||
|
The tensors corresponding to the input audios.
|
||||||
|
feature_attention_mask (`torch.LongTensor`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
|
||||||
|
audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
|
||||||
|
The length of feature shape of each audio in LLM.
|
||||||
|
"""
|
||||||
|
if feature_attention_mask is not None:
|
||||||
|
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||||
|
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||||
|
else:
|
||||||
|
audio_feature_lengths = None
|
||||||
|
|
||||||
|
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
||||||
|
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||||
|
)
|
||||||
|
feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
||||||
|
audio_outputs = self.audio_tower(
|
||||||
|
input_features,
|
||||||
|
feature_lens=feature_lens,
|
||||||
|
aftercnn_lens=audio_feat_lengths,
|
||||||
|
)
|
||||||
|
audio_features = audio_outputs.last_hidden_state
|
||||||
|
|
||||||
|
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
||||||
|
raise ValueError("length of audio_features should match audio_output_lengths")
|
||||||
|
|
||||||
|
return audio_features
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -2278,11 +2347,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
# 1. Extract the input embeddings
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
# 2. Merge text, audio, image and video
|
||||||
|
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
||||||
|
if input_features is not None:
|
||||||
|
audio_features = self.get_audio_features(
|
||||||
|
input_features,
|
||||||
|
feature_attention_mask=feature_attention_mask,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
)
|
||||||
|
audio_mask = (
|
||||||
|
(input_ids == self.config.audio_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||||
|
image_mask = (
|
||||||
|
(input_ids == self.config.image_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||||
|
video_mask = (
|
||||||
|
(input_ids == self.config.video_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
if feature_attention_mask is not None:
|
if feature_attention_mask is not None:
|
||||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||||
input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
|
||||||
else:
|
else:
|
||||||
audio_feature_lengths = None
|
audio_feature_lengths = None
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
if attention_mask is not None and position_ids is None:
|
||||||
if (
|
if (
|
||||||
cache_position is None
|
cache_position is None
|
||||||
@ -2309,63 +2424,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
position_ids = position_ids.add(delta)
|
position_ids = position_ids.add(delta)
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
# 1. Extract the input embeddings
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
# 2. Merge text , audios , image and video
|
|
||||||
if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage
|
|
||||||
if input_features is not None:
|
|
||||||
audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
|
|
||||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
|
||||||
)
|
|
||||||
feature_lens = (
|
|
||||||
audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)
|
|
||||||
)
|
|
||||||
audio_outputs = self.audio_tower(
|
|
||||||
input_features,
|
|
||||||
feature_lens=feature_lens,
|
|
||||||
aftercnn_lens=audio_feat_lengths,
|
|
||||||
)
|
|
||||||
audio_features = audio_outputs.last_hidden_state
|
|
||||||
if audio_features.shape[0] != sum(audio_output_lengths.tolist()):
|
|
||||||
raise ValueError("length of audio_features should match audio_output_lengths")
|
|
||||||
audio_mask = (
|
|
||||||
(input_ids == self.config.audio_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
pixel_values = pixel_values.type(self.visual.dtype)
|
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
||||||
image_mask = (
|
|
||||||
(input_ids == self.config.image_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
|
||||||
|
|
||||||
if pixel_values_videos is not None:
|
|
||||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
||||||
video_mask = (
|
|
||||||
(input_ids == self.config.video_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
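The refactor keeps the placeholder-merge pattern itself unchanged: modality features are scattered, in order, into the rows of `inputs_embeds` that correspond to the special placeholder tokens. A self-contained sketch of that pattern with toy sizes (the hidden size and token id below are made up):

# Toy illustration of the masked_scatter merge used in the forward pass above.
import torch

hidden_size, audio_token_id = 8, 151646  # both values are assumptions for the toy example
input_ids = torch.tensor([[1, audio_token_id, audio_token_id, 2]])
inputs_embeds = torch.zeros(1, 4, hidden_size)
audio_features = torch.randn(2, hidden_size)  # one row per placeholder token

# Build a boolean mask over the embedding tensor and overwrite the masked rows in order.
audio_mask = (input_ids == audio_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features.to(inputs_embeds.dtype))
assert torch.allclose(inputs_embeds[0, 1], audio_features[0])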
@@ -1583,6 +1583,36 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):

         return position_ids, mrope_position_deltas

+    def get_video_features(
+        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+    ):
+        """
+        Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input videos.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM.
+        """
+        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+        return video_embeds
+
+    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image in LLM.
+        """
+        pixel_values = pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        return image_embeds
+
     @auto_docstring
     def forward(
         self,
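Not part of the diff: a sketch of calling the new Qwen2.5-VL helper directly to obtain image embeddings without running the language model. The checkpoint name, the image-processor call, and reaching the helper through the wrapper's `model` attribute are assumptions.

# Illustrative only: checkpoint name and preprocessing call are assumptions.
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id)

image_inputs = processor.image_processor(images=[Image.new("RGB", (224, 224))], return_tensors="pt")
with torch.no_grad():
    image_embeds = model.model.get_image_features(
        image_inputs["pixel_values"], image_inputs["image_grid_thw"]
    )
# image_embeds has one row per visual token; get_video_features behaves the same way for videos.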
@@ -1627,8 +1657,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
             if pixel_values is not None:
-                pixel_values = pixel_values.type(self.visual.dtype)
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
                 n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                 n_image_features = image_embeds.shape[0]
                 if n_image_tokens != n_image_features:
@@ -1645,8 +1674,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
                 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

             if pixel_values_videos is not None:
-                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
                 n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                 n_video_features = video_embeds.shape[0]
                 if n_video_tokens != n_video_features:
@@ -646,8 +646,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
             if pixel_values is not None:
-                pixel_values = pixel_values.type(self.visual.dtype)
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
                 n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                 n_image_features = image_embeds.shape[0]
                 if n_image_tokens != n_image_features:
@@ -664,8 +663,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
                 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

             if pixel_values_videos is not None:
-                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
                 n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                 n_video_features = video_embeds.shape[0]
                 if n_video_tokens != n_video_features:
@@ -1508,6 +1508,36 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):

         return position_ids, mrope_position_deltas

+    def get_video_features(
+        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+    ):
+        """
+        Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input videos.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM.
+        """
+        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+        return video_embeds
+
+    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image in LLM.
+        """
+        pixel_values = pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        return image_embeds
+
     @auto_docstring
     def forward(
         self,
@@ -1549,9 +1579,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
             if pixel_values is not None:
-                pixel_values = pixel_values.type(self.visual.get_dtype())
-                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-                n_image_tokens = (input_ids == self.config.image_token_id).sum()
+                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                 n_image_features = image_embeds.shape[0]
                 if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                     raise ValueError(
@@ -1567,9 +1596,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
                 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

             if pixel_values_videos is not None:
-                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-                n_video_tokens = (input_ids == self.config.video_token_id).sum()
+                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                 n_video_features = video_embeds.shape[0]
                 if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                     raise ValueError(
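Since Qwen2-VL and Qwen2.5-VL now expose the same helper names, downstream code can be written once against the shared surface. The small function below is illustrative only, not part of the change; it assumes the caller already has processor outputs for `pixel_values` and `image_grid_thw`.

# Illustrative wrapper: works for any model instance that exposes the new helper.
import torch

def encode_images(model, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
    # After this change, both Qwen2VLModel and Qwen2_5_VLModel satisfy this check.
    if not hasattr(model, "get_image_features"):
        raise TypeError(f"{type(model).__name__} does not expose get_image_features")
    with torch.no_grad():
        return model.get_image_features(pixel_values, image_grid_thw)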
@@ -632,6 +632,52 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
         merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
         return merged_embeds

+    def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            pixel_attention_mask (`torch.LongTensor`, *optional*):
+                The attention mask indicating padded regions in the image.
+        """
+        batch_size, num_images, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+        # Remove padding images - padding images are full 0.
+        nb_values_per_image = pixel_values.shape[1:].numel()
+        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+
+        if not any(real_images_inds):
+            # no images, leave one empty image.
+            real_images_inds[0] = True
+
+        pixel_values = pixel_values[real_images_inds].contiguous()
+        # Handle the vision attention mask
+        if pixel_attention_mask is None:
+            pixel_attention_mask = torch.ones(
+                size=[pixel_values.shape[i] for i in (0, 2, 3)],
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+        else:
+            # Remove padding images from the mask
+            pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+        patch_size = self.config.vision_config.patch_size
+        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+        # Get sequence from the vision encoder
+        image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+        image_hidden_states = image_hidden_states.last_hidden_state
+
+        # Modality projection & resampling
+        image_hidden_states = self.connector(image_hidden_states)
+        return image_hidden_states
+
     @can_return_tuple
     @auto_docstring(
         custom_intro="""
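Not part of the diff: a sketch of calling SmolVLM's new `get_image_features`, which now also hides the padding-image removal and patch-mask construction shown above. The checkpoint name, prompt format, and the `model` attribute access are assumptions.

# Illustrative only: checkpoint name and processor usage are assumptions.
import torch
from PIL import Image
from transformers import AutoProcessor, SmolVLMForConditionalGeneration

model_id = "HuggingFaceTB/SmolVLM-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = SmolVLMForConditionalGeneration.from_pretrained(model_id)

image = Image.new("RGB", (512, 512))
inputs = processor(text="<image> Describe the image.", images=[image], return_tensors="pt")
with torch.no_grad():
    image_hidden_states = model.model.get_image_features(
        inputs["pixel_values"], inputs.get("pixel_attention_mask")
    )
# The returned hidden states already went through the connector, so they live in the LLM's embedding space.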
@@ -704,48 +750,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel):
         if pixel_values is not None and image_hidden_states is not None:
             raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
         elif pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            pixel_values = pixel_values
-            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
-
-            # Remove padding images - padding images are full 0.
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
-
-            if not any(real_images_inds):
-                # no images, leave one empty image.
-                real_images_inds[0] = True
-
-            pixel_values = pixel_values[real_images_inds].contiguous()
-
-            # Handle the vision attention mask
-            if pixel_attention_mask is None:
-                pixel_attention_mask = torch.ones(
-                    size=[pixel_values.shape[i] for i in (0, 2, 3)],
-                    dtype=torch.bool,
-                    device=pixel_values.device,
-                )
-            else:
-                # Remove padding images from the mask
-                pixel_attention_mask = pixel_attention_mask.view(
-                    batch_size * num_images, *pixel_attention_mask.shape[2:]
-                )
-                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
-
-            patch_size = self.config.vision_config.patch_size
-            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
-            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
-            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(
-                pixel_values=pixel_values,
-                patch_attention_mask=patch_attention_mask,
-            ).last_hidden_state
-
-            # Modality projection & resampling
-            image_hidden_states = self.connector(image_hidden_states)
-
+            image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

@@ -22,7 +22,7 @@ from torch import nn
 from ...cache_utils import DynamicCache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import auto_docstring, can_return_tuple, logging
 from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
 from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
 from ..idefics3.modeling_idefics3 import (
@@ -195,6 +195,64 @@ class SmolVLMModel(Idefics3Model):
         merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
         return merged_embeds

+    def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
+        """
+        Encodes images into continuous embeddings that can be forwarded to the language model.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+            pixel_attention_mask (`torch.LongTensor`, *optional*):
+                The attention mask indicating padded regions in the image.
+        """
+        batch_size, num_images, num_channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+        # Remove padding images - padding images are full 0.
+        nb_values_per_image = pixel_values.shape[1:].numel()
+        real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+
+        if not any(real_images_inds):
+            # no images, leave one empty image.
+            real_images_inds[0] = True
+
+        pixel_values = pixel_values[real_images_inds].contiguous()
+        # Handle the vision attention mask
+        if pixel_attention_mask is None:
+            pixel_attention_mask = torch.ones(
+                size=[pixel_values.shape[i] for i in (0, 2, 3)],
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+        else:
+            # Remove padding images from the mask
+            pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
+            pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+        patch_size = self.config.vision_config.patch_size
+        patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
+        patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
+        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
+        # Get sequence from the vision encoder
+        image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+        image_hidden_states = image_hidden_states.last_hidden_state
+
+        # Modality projection & resampling
+        image_hidden_states = self.connector(image_hidden_states)
+        return image_hidden_states
+
+    @can_return_tuple
+    @auto_docstring(
+        custom_intro="""
+        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
+        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
+        max_num_images is the maximum number of images among the batch_size samples in the batch.
+        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
+        For efficiency, we only pass through the vision_model's forward the real images by
+        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
+        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
+        """
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -249,48 +307,7 @@ class SmolVLMModel(Idefics3Model):
         if pixel_values is not None and image_hidden_states is not None:
             raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
         elif pixel_values is not None:
-            batch_size, num_images, num_channels, height, width = pixel_values.shape
-            pixel_values = pixel_values
-            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
-
-            # Remove padding images - padding images are full 0.
-            nb_values_per_image = pixel_values.shape[1:].numel()
-            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
-
-            if not any(real_images_inds):
-                # no images, leave one empty image.
-                real_images_inds[0] = True
-
-            pixel_values = pixel_values[real_images_inds].contiguous()
-
-            # Handle the vision attention mask
-            if pixel_attention_mask is None:
-                pixel_attention_mask = torch.ones(
-                    size=[pixel_values.shape[i] for i in (0, 2, 3)],
-                    dtype=torch.bool,
-                    device=pixel_values.device,
-                )
-            else:
-                # Remove padding images from the mask
-                pixel_attention_mask = pixel_attention_mask.view(
-                    batch_size * num_images, *pixel_attention_mask.shape[2:]
-                )
-                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
-
-            patch_size = self.config.vision_config.patch_size
-            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
-            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
-            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-            # Get sequence from the vision encoder
-            image_hidden_states = self.vision_model(
-                pixel_values=pixel_values,
-                patch_attention_mask=patch_attention_mask,
-            ).last_hidden_state
-
-            # Modality projection & resampling
-            image_hidden_states = self.connector(image_hidden_states)
-
+            image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

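For reference, a standalone toy reproduction of the unfold-based patch-mask construction used by the helper above; only the tensor shapes are invented.

# Toy illustration: a pixel-level padding mask is carved into patch_size x patch_size tiles,
# and a patch is kept if any pixel inside it belongs to a real (non-padded) region.
import torch

patch_size = 4
pixel_attention_mask = torch.zeros(1, 16, 16, dtype=torch.bool)
pixel_attention_mask[:, :8, :8] = True  # only the top-left quadrant holds real image content

patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask.shape)  # torch.Size([1, 4, 4]); only the top-left 2x2 patches are True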
@@ -205,8 +205,8 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
     def get_image_features(
         self,
         pixel_values_images: torch.FloatTensor,
-        vision_feature_layer: Union[int, List[int]],
-        vision_feature_select_strategy: str,
+        vision_feature_layer: Optional[Union[int, List[int]]] = None,
+        vision_feature_select_strategy: Optional[str] = None,
     ):
         """
         Obtains image last hidden states from the vision tower and apply multimodal projection.
@@ -214,16 +214,25 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
         Args:
             pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
-            vision_feature_layer (`Union[int, List[int]]`):
+            vision_feature_layer (`Union[int, List[int]]`, *optional*):
                 The index of the layer to select the vision feature. If multiple indices are provided,
                 the vision feature of the corresponding indices will be concatenated to form the
                 vision features.
-            vision_feature_select_strategy (`str`):
+            vision_feature_select_strategy (`str`, *optional*):
                 The feature selection strategy used to select the vision feature from the vision backbone.
                 Can be one of `"default"` or `"full"`
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
         if vision_feature_select_strategy not in ["default", "full"]:
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

@@ -249,7 +258,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
     def get_video_features(
         self,
         pixel_values_videos: torch.FloatTensor,
-        vision_feature_layer: Union[int, List[int]],
+        vision_feature_layer: Optional[Union[int, List[int]]] = None,
     ):
         """
         Obtains video last hidden states from the vision tower and apply multimodal projection.
@@ -257,7 +266,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
         Args:
             pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
                The tensors corresponding to the input videos.
-            vision_feature_layer (`Union[int, List[int]]`):
+            vision_feature_layer (`Union[int, List[int]]`, *optional*):
                 The index of the layer to select the vision feature. If multiple indices are provided,
                 the vision feature of the corresponding indices will be concatenated to form the
                 vision features.
@@ -265,6 +274,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel):
             video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`).
             frames (`int`): Number of frames the videos have.
         """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+
         batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape

         pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
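Not part of the diff: a sketch showing the practical effect of the new defaults, i.e. VideoLLaVA's `get_image_features` can now be called without `vision_feature_layer`/`vision_feature_select_strategy` and falls back to the values stored in the model config. The checkpoint name, processor call, and the `model` attribute access are assumptions.

# Illustrative only: checkpoint name, prompt format and processor output keys are assumptions.
import torch
from PIL import Image
from transformers import AutoProcessor, VideoLlavaForConditionalGeneration

model_id = "LanguageBind/Video-LLaVA-7B-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id)

image = Image.new("RGB", (336, 336))
inputs = processor(text="USER: <image> What is shown? ASSISTANT:", images=[image], return_tensors="pt")
with torch.no_grad():
    # No vision_feature_layer / vision_feature_select_strategy needed anymore.
    image_features = model.model.get_image_features(inputs["pixel_values_images"])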