diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 9a02e1c58d2..0d516c5f1c0 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1197,7 +1197,7 @@ class AriaModel(AriaPreTrainedModel): def get_image_features( self, pixel_values: torch.FloatTensor, - pixel_mask: torch.FloatTensor = None, + pixel_mask: Optional[torch.FloatTensor] = None, vision_feature_layer: int = -1, ): """ @@ -1208,13 +1208,16 @@ class AriaModel(AriaPreTrainedModel): The tensors corresponding to the input images. pixel_mask (`torch.FloatTensor]`, *optional*): The tensors corresponding to the input image mask. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) patch_attention_mask = self._create_patch_attention_mask(pixel_mask) image_outputs = self.vision_tower( pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index c741a9b2c4e..738b269b0bb 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1325,7 +1325,7 @@ class AriaModel(LlavaModel): def get_image_features( self, pixel_values: torch.FloatTensor, - pixel_mask: torch.FloatTensor = None, + pixel_mask: Optional[torch.FloatTensor] = None, vision_feature_layer: int = -1, ): """ @@ -1336,13 +1336,16 @@ class AriaModel(LlavaModel): The tensors corresponding to the input images. pixel_mask (`torch.FloatTensor]`, *optional*): The tensors corresponding to the input image mask. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) patch_attention_mask = self._create_patch_attention_mask(pixel_mask) image_outputs = self.vision_tower( pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 0e002635849..b7d7521ef69 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -213,8 +213,8 @@ class AyaVisionModel(AyaVisionPreTrainedModel): def get_image_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, **kwargs, ): """ @@ -223,16 +223,25 @@ class AyaVisionModel(AyaVisionPreTrainedModel): Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + if vision_feature_select_strategy not in ["default", "full"]: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index ff5ebb952af..3ca38af6add 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1956,6 +1956,50 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + def get_image_features( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_outputs = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs[0] + + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + if return_dict: + return language_model_inputs, vision_outputs, query_outputs + return language_model_inputs + @auto_docstring def forward( self, @@ -2047,37 +2091,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # step 1: forward the images through the vision encoder, - # to get image embeddings of shape (batch_size, seq_len, hidden_size) - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - interpolate_pos_encoding=interpolate_pos_encoding, + language_model_inputs, vision_outputs, query_outputs = self.get_image_features( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True ) - image_embeds = vision_outputs[0] - - # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_outputs = self.qformer( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - query_output = query_outputs[0] - - # Qformer is kept in fp32, we downcast the output back if needed - if query_output.dtype != image_embeds.dtype: - query_output = query_output.to(image_embeds.dtype) - - # step 3: use the language model, conditioned on the query outputs and the prompt - language_model_inputs = self.language_projection(query_output) + vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs + query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs language_model_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 624450371bd..8b5d2a46067 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -904,6 +904,12 @@ class ChameleonModel(ChameleonPreTrainedModel): self.embed_tokens = value def get_image_tokens(self, pixel_values: torch.FloatTensor): + logger.warning( + "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`" + ) + return self.get_image_featues(pixel_values) + + def get_image_features(self, pixel_values: torch.FloatTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" @@ -957,7 +963,7 @@ class ChameleonModel(ChameleonPreTrainedModel): ) if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values) + image_tokens = self.get_image_features(pixel_values) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel(): n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 83136bfba2b..61ccecd6501 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1586,6 +1586,12 @@ class Emu3Model(Emu3PreTrainedModel): self.text_model.set_input_embeddings(value) def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + logger.warning( + "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`" + ) + return self.get_image_featues(pixel_values) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" @@ -1662,7 +1668,7 @@ class Emu3Model(Emu3PreTrainedModel): ) if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values, image_sizes) + image_tokens = self.get_image_features(pixel_values, image_sizes) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index b1c27f961bc..ade55f93a16 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -941,6 +941,12 @@ class Emu3Model(Emu3PreTrainedModel): self.text_model.set_input_embeddings(value) def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + logger.warning( + "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`" + ) + return self.get_image_featues(pixel_values) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" @@ -1017,7 +1023,7 @@ class Emu3Model(Emu3PreTrainedModel): ) if pixel_values is not None: - image_tokens = self.get_image_tokens(pixel_values, image_sizes) + image_tokens = self.get_image_features(pixel_values, image_sizes) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index ef3cd7309fd..4f120afa1ba 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -125,9 +125,25 @@ class FuyuModel(FuyuPreTrainedModel): f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match " f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}." ) - output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices] + output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to( + output_embeddings.device + ) return output_embeddings + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + patch_embeddings = [ + self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0) + for patch in pixel_values + ] + return patch_embeddings + @auto_docstring def forward( self, @@ -185,12 +201,7 @@ class FuyuModel(FuyuPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: - patch_embeddings = [ - self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)) - .squeeze(0) - .to(inputs_embeds.device) - for patch in image_patches - ] + patch_embeddings = self.get_image_features(image_patches) inputs_embeds = self.gather_continuous_embeddings( word_embeddings=inputs_embeds, continuous_embeddings=patch_embeddings, diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index a07e480d5c3..f7008ad33e8 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -961,13 +961,57 @@ class Idefics2Model(Idefics2PreTrainedModel): - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. """ - num_images, _, vision_hidden_size = image_hidden_states.shape special_image_token_mask = input_ids == self.image_token_id new_inputs_embeds = inputs_embeds.clone() - reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size) - new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device) + new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device) return new_inputs_embeds + def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + pixel_attention_mask (`torch.LongTensor`, *optional*): + The attention mask indicating padded regions in the image. + """ + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:]) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool() + # Get sequence from the vision encoder + image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + image_hidden_states = image_hidden_states.last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1) + ) + image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) + return image_hidden_states + @can_return_tuple @auto_docstring( custom_intro=""" @@ -1052,45 +1096,7 @@ class Idefics2Model(Idefics2PreTrainedModel): if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask/pP p - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector( - image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1) - ) - + image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask) elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 90e15821a3f..6bcdd3594e4 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -692,16 +692,59 @@ class Idefics3Model(Idefics3PreTrainedModel): - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM. - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. """ - num_images, _, vision_hidden_size = image_hidden_states.shape special_image_token_mask = input_ids == self.image_token_id # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. new_inputs_embeds = inputs_embeds.clone() - reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size) # cast to the dtype of the input_embeds to support quantized models - reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) - new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states + image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) + new_inputs_embeds[special_image_token_mask] = image_hidden_states return new_inputs_embeds + def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + pixel_attention_mask (`torch.LongTensor`, *optional*): + The attention mask indicating padded regions in the image. + """ + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:]) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + image_hidden_states.last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states.last_hidden_state) + image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) + return image_hidden_states + @can_return_tuple @auto_docstring( custom_intro=""" @@ -774,43 +817,7 @@ class Idefics3Model(Idefics3PreTrainedModel): if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector(image_hidden_states) - + image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask) elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index d5aa8efc46b..c90d22f012d 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1466,6 +1466,55 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + def get_image_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + # step 1: forward the images through the vision encoder, + # to get image embeddings of shape (batch_size, seq_len, hidden_size) + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + if return_dict: + return language_model_inputs, vision_outputs, query_outputs + return language_model_inputs + @can_return_tuple @auto_docstring def forward( @@ -1555,40 +1604,15 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # step 1: forward the images through the vision encoder, - # to get image embeddings of shape (batch_size, seq_len, hidden_size) - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + language_model_inputs, vision_outputs, query_outputs = self.get_image_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, ) - image_embeds = vision_outputs[0] - - # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) - if qformer_attention_mask is None: - qformer_attention_mask = torch.ones_like(qformer_input_ids) - qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) - query_outputs = self.qformer( - input_ids=qformer_input_ids, - attention_mask=qformer_attention_mask, - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - query_output = query_outputs[0][:, : query_tokens.size(1), :] - - # step 3: use the language model, conditioned on the query outputs and the prompt - language_model_inputs = self.language_projection(query_output) + vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs + query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs language_model_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -1690,30 +1714,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati self._preprocess_accelerate() batch_size = pixel_values.shape[0] - image_embeds = self.vision_model( + language_model_inputs, vision_outputs, query_outputs = self.get_image_features( pixel_values, - return_dict=True, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, interpolate_pos_encoding=interpolate_pos_encoding, - ).last_hidden_state - - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) - if qformer_attention_mask is None: - qformer_attention_mask = torch.ones_like(qformer_input_ids) - qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) - query_outputs = self.qformer( - input_ids=qformer_input_ids, - attention_mask=qformer_attention_mask, - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, return_dict=True, ) - query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :] - - language_model_inputs = self.language_projection(query_output) language_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -1722,7 +1729,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati start_tokens = [self.config.text_config.bos_token_id] if getattr(self.config, "image_token_id", None) is not None: start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) input_ids = input_ids.repeat(batch_size, 1) if attention_mask is None: diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 48eaa4fa85d..b9f40deffef 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1470,6 +1470,23 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + def get_image_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + pass + @can_return_tuple @auto_docstring def forward( @@ -1582,50 +1599,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # step 1: forward the images through the vision encoder, - # we process in a batched way, later unbatch it back (video has frames=4 always) - batch_size, frames, channel, height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + language_model_inputs, vision_outputs, query_outputs = self.get_video_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, ) - image_embeds = vision_outputs[0] - - # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) - - if qformer_attention_mask is None: - qformer_attention_mask = torch.ones_like(qformer_input_ids) - - qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) - qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) - qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) - query_outputs = self.qformer( - input_ids=qformer_input_ids, - attention_mask=qformer_attention_mask, - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - query_output = query_outputs[0][:, : query_tokens.size(1), :] - - # step 3: use the language model, conditioned on the query outputs and the prompt - language_model_inputs = self.language_projection(query_output) - - # unbatch inputs back, each video-frame gets `num_query_tokens` seq length - language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) + vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs + query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs language_model_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -1726,39 +1708,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel # preprocess for `accelerate` self._preprocess_accelerate() - # we process in a batched way, later unbatch it back (video has frames=4) - batch_size, frames, channel, height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) - - image_embeds = self.vision_model( + batch_size = pixel_values.shape[0] + language_model_inputs, vision_outputs, query_outputs = self.get_video_features( pixel_values, - return_dict=True, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, interpolate_pos_encoding=interpolate_pos_encoding, - ).last_hidden_state - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) - if qformer_attention_mask is None: - qformer_attention_mask = torch.ones_like(qformer_input_ids) - - qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) - qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) - qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) - query_outputs = self.qformer( - input_ids=qformer_input_ids, - attention_mask=qformer_attention_mask, - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, return_dict=True, ) - query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :] - language_model_inputs = self.language_projection(query_output) - - # unbatch the embeddings back by moving frames to seq-len - language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) language_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -1767,7 +1725,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel start_tokens = [self.config.text_config.bos_token_id] if getattr(self.config, "video_token_id", None) is not None: start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) input_ids = input_ids.repeat(batch_size, 1) if attention_mask is None: @@ -1807,6 +1765,65 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel return outputs + def get_video_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + # step 1: forward the images through the vision encoder, + # we process in a batched way, later unbatch it back (video has frames=4 always) + batch_size, frames, channel, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + + qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) + qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + + # unbatch inputs back, each video-frame gets `num_query_tokens` seq length + language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) + if return_dict: + return language_model_inputs, vision_outputs, query_outputs + return language_model_inputs + __all__ = [ "InstructBlipVideoVisionModel", diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index 44a0022ee4f..079abe579ed 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -295,6 +295,76 @@ class InstructBlipVideoModel(InstructBlipModel): class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): + def get_video_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + """ + # step 1: forward the images through the vision encoder, + # we process in a batched way, later unbatch it back (video has frames=4 always) + batch_size, frames, channel, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, + ) + image_embeds = vision_outputs[0] + + # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) + + # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) + + if qformer_attention_mask is None: + qformer_attention_mask = torch.ones_like(qformer_input_ids) + + qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) + qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) + qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) + query_outputs = self.qformer( + input_ids=qformer_input_ids, + attention_mask=qformer_attention_mask, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=True, + ) + query_output = query_outputs[0][:, : query_tokens.size(1), :] + + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) + + # unbatch inputs back, each video-frame gets `num_query_tokens` seq length + language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) + if return_dict: + return language_model_inputs, vision_outputs, query_outputs + return language_model_inputs + + # Model supports only videos + def get_image_features( + self, + pixel_values: torch.FloatTensor, + qformer_input_ids: torch.LongTensor, + qformer_attention_mask: Optional[torch.LongTensor] = None, + interpolate_pos_encoding: Optional[bool] = False, + return_dict: Optional[bool] = False, + ): + pass + def forward( self, pixel_values: torch.FloatTensor, @@ -370,50 +440,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # step 1: forward the images through the vision encoder, - # we process in a batched way, later unbatch it back (video has frames=4 always) - batch_size, frames, channel, height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) - - vision_outputs = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + language_model_inputs, vision_outputs, query_outputs = self.get_video_features( + pixel_values, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=True, ) - image_embeds = vision_outputs[0] - - # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) - - if qformer_attention_mask is None: - qformer_attention_mask = torch.ones_like(qformer_input_ids) - - qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) - qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) - qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) - query_outputs = self.qformer( - input_ids=qformer_input_ids, - attention_mask=qformer_attention_mask, - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - query_output = query_outputs[0][:, : query_tokens.size(1), :] - - # step 3: use the language model, conditioned on the query outputs and the prompt - language_model_inputs = self.language_projection(query_output) - - # unbatch inputs back, each video-frame gets `num_query_tokens` seq length - language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) + vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs + query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs language_model_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -514,39 +549,15 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera # preprocess for `accelerate` self._preprocess_accelerate() - # we process in a batched way, later unbatch it back (video has frames=4) - batch_size, frames, channel, height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) - - image_embeds = self.vision_model( + batch_size = pixel_values.shape[0] + language_model_inputs, vision_outputs, query_outputs = self.get_video_features( pixel_values, - return_dict=True, + qformer_input_ids=qformer_input_ids, + qformer_attention_mask=qformer_attention_mask, interpolate_pos_encoding=interpolate_pos_encoding, - ).last_hidden_state - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) - - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device) - if qformer_attention_mask is None: - qformer_attention_mask = torch.ones_like(qformer_input_ids) - - qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0) - qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0) - qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1) - query_outputs = self.qformer( - input_ids=qformer_input_ids, - attention_mask=qformer_attention_mask, - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_attention_mask, return_dict=True, ) - query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :] - language_model_inputs = self.language_projection(query_output) - - # unbatch the embeddings back by moving frames to seq-len - language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1) language_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device ) @@ -555,7 +566,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera start_tokens = [self.config.text_config.bos_token_id] if getattr(self.config, "video_token_id", None) is not None: start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens - input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device) + input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device) input_ids = input_ids.repeat(batch_size, 1) if attention_mask is None: diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 980996b2820..d25f100a234 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1625,6 +1625,37 @@ class Kosmos2Model(Kosmos2PreTrainedModel): def set_input_embeddings(self, value): self.text_model.model.embed_tokens = value + def get_image_features( + self, + pixel_values: torch.FloatTensor, + return_attentions: Optional[bool] = False, + interpolate_pos_encoding: Optional[bool] = False, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + return_attentions (`bool`, *optional*, defaults to `False`): + Whether to return `projection_attentions` or not. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate positional embeddings or not. + """ + vision_model_output = self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. + image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) + # normalized features + image_embeds = nn.functional.normalize(image_embeds, dim=-1) + image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) + + if return_attentions: + return image_embeds, projection_attentions + return image_embeds + @can_return_tuple @auto_docstring def forward( @@ -1696,19 +1727,9 @@ class Kosmos2Model(Kosmos2PreTrainedModel): if image_embeds is None: if pixel_values is None: raise ValueError("You have to specify either `pixel_values` or `image_embeds`.") - - vision_model_output = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, + image_embeds, projection_attentions = self.get_image_features( + pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding ) - # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`. - image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0]) - # normalized features - image_embeds = nn.functional.normalize(image_embeds, dim=-1) - image_embeds, projection_attentions = self.image_to_text_projection(image_embeds) outputs = self.text_model( input_ids=input_ids, diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index ebe0e6c5990..b27624241ac 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -183,8 +183,8 @@ class LlavaModel(LlavaPreTrainedModel): def get_image_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, **kwargs, ): """ @@ -193,16 +193,25 @@ class LlavaModel(LlavaPreTrainedModel): Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + if vision_feature_select_strategy not in ["default", "full"]: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 718c1e96c98..fa92bc62236 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -365,8 +365,8 @@ class LlavaNextModel(LlavaNextPreTrainedModel): self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -376,17 +376,26 @@ class LlavaNextModel(LlavaNextPreTrainedModel): The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + # ! infer image_num_patches from image_sizes image_num_patches = [ image_size_to_num_patches( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index b956b6f1bd0..3cb81ada8ac 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -419,8 +419,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -430,17 +430,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + # ! infer image_num_patches from image_sizes image_num_patches = [ image_size_to_num_patches( @@ -600,8 +609,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): def get_video_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains video last hidden states from the vision tower and apply multimodal projection. @@ -609,17 +618,26 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) The tensors corresponding to the input video. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optiona;*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches and are of shape `(num_videos, video_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + batch_size, frames, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width) video_features = self.vision_tower(pixel_values, output_hidden_states=True) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 88ef859ad29..91d16caab6b 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -252,8 +252,8 @@ class LlavaNextVideoModel(LlavaNextModel): self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -263,17 +263,26 @@ class LlavaNextVideoModel(LlavaNextModel): The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + # ! infer image_num_patches from image_sizes image_num_patches = [ image_size_to_num_patches( @@ -311,8 +320,8 @@ class LlavaNextVideoModel(LlavaNextModel): def get_video_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains video last hidden states from the vision tower and apply multimodal projection. @@ -320,17 +329,26 @@ class LlavaNextVideoModel(LlavaNextModel): Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) The tensors corresponding to the input video. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optiona;*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches and are of shape `(num_videos, video_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + batch_size, frames, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width) video_features = self.vision_tower(pixel_values, output_hidden_states=True) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 18475d381a0..be600e8a96e 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -419,8 +419,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -430,17 +430,26 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): The tensors corresponding to the input images. image_sizes (`torch.Tensor` of shape `(num_images, 2)`) Actual image size of each images (H, W). - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + # ! infer image_num_patches from image_sizes image_num_patches = [ image_size_to_num_patches( diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 3c57c816fab..5f0c9604760 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -255,8 +255,8 @@ class Mistral3Model(Mistral3PreTrainedModel): def get_image_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, List[int]]] = None, **kwargs, ): """ @@ -265,15 +265,19 @@ class Mistral3Model(Mistral3PreTrainedModel): Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - image_sizes (`torch.Tensor`): + image_sizes (`torch.Tensor`, *optional*): Tensor containing the image sizes as returned by the processor. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 7494afdddc8..a3b5fa5ecab 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -135,8 +135,8 @@ class Mistral3Model(LlavaModel): def get_image_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], image_sizes: torch.Tensor, + vision_feature_layer: Optional[Union[int, List[int]]] = None, **kwargs, ): """ @@ -145,15 +145,19 @@ class Mistral3Model(LlavaModel): Args: pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - image_sizes (`torch.Tensor`): + image_sizes (`torch.Tensor`, *optional*): Tensor containing the image sizes as returned by the processor. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 926b19a9b71..ce785d7a7b5 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -2187,6 +2187,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def get_audio_features( + self, + input_features: torch.FloatTensor, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + ): + """ + Encodes audios into continuous embeddings that can be forwarded to the language model. + + Args: + input_features (`torch.FloatTensor`): + The tensors corresponding to the input audios. + feature_attention_mask (`torch.LongTensor`, *optional*): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + """ + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + if audio_features.shape[0] != sum(audio_output_lengths.tolist()): + raise ValueError("length of audio_features should match audio_output_lengths") + + return audio_features + @auto_docstring def forward( self, @@ -2284,11 +2353,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text , audios , image and video + if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + audio_mask = ( + (input_ids == self.config.audio_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) - input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) else: audio_feature_lengths = None + if attention_mask is not None and position_ids is None: if ( cache_position is None @@ -2315,63 +2430,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - if inputs_embeds is None: - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - - # 2. Merge text , audios , image and video - if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage - if input_features is not None: - audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( - audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - ) - feature_lens = ( - audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - ) - audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - aftercnn_lens=audio_feat_lengths, - ) - audio_features = audio_outputs.last_hidden_state - if audio_features.shape[0] != sum(audio_output_lengths.tolist()): - raise ValueError("length of audio_features should match audio_output_lengths") - audio_mask = ( - (input_ids == self.config.audio_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) - - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index ef2094e4250..3d55bf10f40 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -2181,6 +2181,75 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def get_audio_features( + self, + input_features: torch.FloatTensor, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + ): + """ + Encodes audios into continuous embeddings that can be forwarded to the language model. + + Args: + input_features (`torch.FloatTensor`): + The tensors corresponding to the input audios. + feature_attention_mask (`torch.LongTensor`, *optional*): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + """ + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + if audio_features.shape[0] != sum(audio_output_lengths.tolist()): + raise ValueError("length of audio_features should match audio_output_lengths") + + return audio_features + @auto_docstring def forward( self, @@ -2278,11 +2347,57 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if inputs_embeds is None: + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text , audios , image and video + if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + audio_mask = ( + (input_ids == self.config.audio_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) - input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) else: audio_feature_lengths = None + if attention_mask is not None and position_ids is None: if ( cache_position is None @@ -2309,63 +2424,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - if inputs_embeds is None: - # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) - - # 2. Merge text , audios , image and video - if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage - if input_features is not None: - audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( - audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - ) - feature_lens = ( - audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - ) - audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - aftercnn_lens=audio_feat_lengths, - ) - audio_features = audio_outputs.last_hidden_state - if audio_features.shape[0] != sum(audio_output_lengths.tolist()): - raise ValueError("length of audio_features should match audio_output_lengths") - audio_mask = ( - (input_ids == self.config.audio_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) - - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - outputs = self.model( attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 82ac330dd98..12d98835338 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1583,6 +1583,36 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): return position_ids, mrope_position_deltas + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + @auto_docstring def forward( self, @@ -1627,8 +1657,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_embeds = self.get_image_features(pixel_values, image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: @@ -1645,8 +1674,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 4f9882d2233..0b3fd6ea0bc 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -646,8 +646,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + image_embeds = self.get_image_features(pixel_values, image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: @@ -664,8 +663,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 4e78a259154..47e54e76ae8 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1508,6 +1508,36 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): return position_ids, mrope_position_deltas + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + @auto_docstring def forward( self, @@ -1549,9 +1579,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum() + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: raise ValueError( @@ -1567,9 +1596,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum() + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: raise ValueError( diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 0f8fcc025d9..bbf05404ac1 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -632,6 +632,52 @@ class SmolVLMModel(SmolVLMPreTrainedModel): merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) return merged_embeds + def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + pixel_attention_mask (`torch.LongTensor`, *optional*): + The attention mask indicating padded regions in the image. + """ + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:]) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + image_hidden_states = image_hidden_states.last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + return image_hidden_states + @can_return_tuple @auto_docstring( custom_intro=""" @@ -704,48 +750,7 @@ class SmolVLMModel(SmolVLMPreTrainedModel): if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - - if not any(real_images_inds): - # no images, leave one empty image. - real_images_inds[0] = True - - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=[pixel_values.shape[i] for i in (0, 2, 3)], - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector(image_hidden_states) - + image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask) elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) diff --git a/src/transformers/models/smolvlm/modular_smolvlm.py b/src/transformers/models/smolvlm/modular_smolvlm.py index 95d3c4dadae..b2263db30e4 100644 --- a/src/transformers/models/smolvlm/modular_smolvlm.py +++ b/src/transformers/models/smolvlm/modular_smolvlm.py @@ -22,7 +22,7 @@ from torch import nn from ...cache_utils import DynamicCache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack -from ...utils import logging +from ...utils import auto_docstring, can_return_tuple, logging from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor from ..idefics3.modeling_idefics3 import ( @@ -195,6 +195,64 @@ class SmolVLMModel(Idefics3Model): merged_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds) return merged_embeds + def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + pixel_attention_mask (`torch.LongTensor`, *optional*): + The attention mask indicating padded regions in the image. + """ + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + + if not any(real_images_inds): + # no images, leave one empty image. + real_images_inds[0] = True + + pixel_values = pixel_values[real_images_inds].contiguous() + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=[pixel_values.shape[i] for i in (0, 2, 3)], + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:]) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + image_hidden_states = image_hidden_states.last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + return image_hidden_states + + @can_return_tuple + @auto_docstring( + custom_intro=""" + Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to + the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where + max_num_images is the maximum number of images among the batch_size samples in the batch. + Padding images are not needed beyond padding the pixel_values at the entrance of the model. + For efficiency, we only pass through the vision_model's forward the real images by + discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where + image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3. + """ + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -249,48 +307,7 @@ class SmolVLMModel(Idefics3Model): if pixel_values is not None and image_hidden_states is not None: raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") elif pixel_values is not None: - batch_size, num_images, num_channels, height, width = pixel_values.shape - pixel_values = pixel_values - pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # Remove padding images - padding images are full 0. - nb_values_per_image = pixel_values.shape[1:].numel() - real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - - if not any(real_images_inds): - # no images, leave one empty image. - real_images_inds[0] = True - - pixel_values = pixel_values[real_images_inds].contiguous() - - # Handle the vision attention mask - if pixel_attention_mask is None: - pixel_attention_mask = torch.ones( - size=[pixel_values.shape[i] for i in (0, 2, 3)], - dtype=torch.bool, - device=pixel_values.device, - ) - else: - # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask.view( - batch_size * num_images, *pixel_attention_mask.shape[2:] - ) - pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() - - patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # Get sequence from the vision encoder - image_hidden_states = self.vision_model( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - ).last_hidden_state - - # Modality projection & resampling - image_hidden_states = self.connector(image_hidden_states) - + image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask) elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 8b4718959e8..535c796e8ab 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -205,8 +205,8 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): def get_image_features( self, pixel_values_images: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -214,16 +214,25 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): Args: pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) The tensors corresponding to the input images. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. - vision_feature_select_strategy (`str`): + vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can be one of `"default"` or `"full"` Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + if vision_feature_select_strategy not in ["default", "full"]: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") @@ -249,7 +258,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): def get_video_features( self, pixel_values_videos: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], + vision_feature_layer: Optional[Union[int, List[int]]] = None, ): """ Obtains video last hidden states from the vision tower and apply multimodal projection. @@ -257,7 +266,7 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): Args: pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) The tensors corresponding to the input videos. - vision_feature_layer (`Union[int, List[int]]`): + vision_feature_layer (`Union[int, List[int]]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. @@ -265,6 +274,10 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`). frames (`int`): Number of frames the videos have. """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)