Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Refactor image features selection in LlaVa (#33696)
* refactor image features selection
* break line
* remove whitespace
* add pr comments: include projection and rename function
* make fix-copies
* fix get_image_features in vip llava
parent 22266be970
commit 88d960937c
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -279,6 +279,21 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    def get_image_features(
+        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
+    ):
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
+        selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+        image_features = self.multi_modal_projector(selected_image_feature)
+        return image_features
+
     def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         num_images, num_image_patches, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
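For readers skimming the hunk above: the two strategy branches differ only in whether the vision tower's CLS token is kept. A minimal toy sketch, not part of the commit; the sizes stand in for a CLIP-L/336 layer output:

import torch

# Stand-in for image_outputs.hidden_states[vision_feature_layer]:
# 2 images, 1 CLS token + 576 patch tokens, hidden size 1024.
selected_image_feature = torch.randn(2, 577, 1024)

default_features = selected_image_feature[:, 1:]  # "default": drop the CLS token
full_features = selected_image_feature            # "full": keep every token

print(default_features.shape)  # torch.Size([2, 576, 1024])
print(full_features.shape)     # torch.Size([2, 577, 1024])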
@@ -450,17 +465,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
         if pixel_values is not None:
-            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-            # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
-            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-
-            image_features = self.multi_modal_projector(selected_image_feature)
+            image_features = self.get_image_features(
+                pixel_values=pixel_values,
+                vision_feature_layer=vision_feature_layer,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+            )
 
         if legacy_processing:
             logger.warning_once(
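With the logic factored out, projected image features can be obtained without running the whole forward pass. A hedged usage sketch, not part of the commit; the llava-hf/llava-1.5-7b-hf checkpoint and the blank placeholder image are illustrative assumptions:

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Placeholder input; any PIL image works here.
image = Image.new("RGB", (336, 336))
pixel_values = processor.image_processor(image, return_tensors="pt").pixel_values

with torch.no_grad():
    image_features = model.get_image_features(
        pixel_values=pixel_values,
        vision_feature_layer=model.config.vision_feature_layer,
        vision_feature_select_strategy=model.config.vision_feature_select_strategy,
    )
print(image_features.shape)  # (num_images, num_patches, text_hidden_size)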
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -282,6 +282,17 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    # Ignore copy
+    def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: list[int]):
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+
+        # For VIP-llava, the image features are computed this way
+        # We select the features from index 1: for the layers -2, -5, -8, -11 and 6
+        image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
+        image_features = torch.cat(image_features, dim=-1)
+        image_features = self.multi_modal_projector(image_features)
+        return image_features
+
     def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         num_images, num_image_patches, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
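Unlike LLaVA, VipLlava pools patch features from several encoder layers and concatenates them along the hidden dimension before projecting. A toy sketch of that selection (fake tensors, not part of the commit; 25 hidden states stand in for a CLIP-L tower, i.e. the embeddings plus 24 blocks):

import torch

hidden_size = 1024
# Fake per-layer hidden states: 1 image, 1 CLS token + 4 patch tokens each.
hidden_states = [torch.randn(1, 5, hidden_size) for _ in range(25)]

vision_feature_layers = [-2, -5, -8, -11, 6]  # VipLlava's default layers
image_features = [hidden_states[index][:, 1:] for index in vision_feature_layers]  # drop CLS
image_features = torch.cat(image_features, dim=-1)

print(image_features.shape)  # torch.Size([1, 4, 5120]) -> input to the projector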
@@ -451,13 +462,9 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
         if pixel_values is not None:
-            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-
-            # For VIP-llava, the image features are computed this way
-            # We select the features from index 1: for the layers -2, -5, -8, -11 and 6
-            image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
-            image_features = torch.cat(image_features, dim=-1)
-            image_features = self.multi_modal_projector(image_features)
+            image_features = self.get_image_features(
+                pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
+            )
 
         if legacy_processing:
             logger.warning_once(
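The matching call for VipLlava mirrors the LLaVA sketch above; the llava-hf/vip-llava-7b-hf checkpoint is again an illustrative assumption, not part of the commit:

import torch
from PIL import Image
from transformers import AutoProcessor, VipLlavaForConditionalGeneration

model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf")
processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")

# Placeholder input; any PIL image works here.
pixel_values = processor.image_processor(Image.new("RGB", (336, 336)), return_tensors="pt").pixel_values

with torch.no_grad():
    image_features = model.get_image_features(
        pixel_values=pixel_values,
        vision_feature_layers=model.config.vision_feature_layers,
    )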