Refactor image features selection in LlaVa (#33696)

* refactor image features selection

* break line

* remove whitespace

* address PR comments: include projection and rename function

* make fix-copies

* fix get_image_features in vip llava
Kenza Bouzid 2024-10-01 13:37:31 +01:00 committed by GitHub
parent 22266be970
commit 88d960937c
2 changed files with 34 additions and 18 deletions


@@ -279,6 +279,21 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    def get_image_features(
+        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
+    ):
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # This is not memory efficient at all: output_hidden_states=True saves all the hidden states.
+        selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
+        image_features = self.multi_modal_projector(selected_image_feature)
+        return image_features
+
     def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         num_images, num_image_patches, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
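
In the new helper, the "default" strategy drops the vision tower's class token (position 0) and keeps only the patch embeddings, while "full" keeps every token. A minimal sketch of the slicing semantics, with illustrative shapes only (assuming a CLIP ViT-L/14-336 tower, which emits 1 class token plus 576 patch tokens at hidden size 1024):

    import torch

    # Stand-in for image_outputs.hidden_states[vision_feature_layer]
    hidden_state = torch.randn(2, 577, 1024)  # (num_images, 1 + num_patches, hidden_dim)

    default = hidden_state[:, 1:]  # "default": drop the class token -> (2, 576, 1024)
    full = hidden_state            # "full": keep all tokens         -> (2, 577, 1024)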
@@ -450,17 +465,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
         if pixel_values is not None:
-            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-            # This is not memory efficient at all: output_hidden_states=True saves all the hidden states.
-            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-            image_features = self.multi_modal_projector(selected_image_feature)
+            image_features = self.get_image_features(
+                pixel_values=pixel_values,
+                vision_feature_layer=vision_feature_layer,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+            )
 
             if legacy_processing:
                 logger.warning_once(
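
With selection and projection factored out, projected image features can be obtained in one call. A usage sketch, not part of this commit (it assumes the "llava-hf/llava-1.5-7b-hf" checkpoint, a local image file, and the vision_feature_layer / vision_feature_select_strategy entries that LlavaConfig defines by default):

    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, LlavaForConditionalGeneration

    model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint
    model = LlavaForConditionalGeneration.from_pretrained(model_id)
    image_processor = AutoImageProcessor.from_pretrained(model_id)

    image = Image.open("cat.png")  # placeholder: any local RGB image
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        image_features = model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=model.config.vision_feature_layer,
            vision_feature_select_strategy=model.config.vision_feature_select_strategy,
        )
    # image_features has shape (num_images, num_patches, text_hidden_size);
    # it has already been passed through the multi-modal projector.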


@@ -282,6 +282,17 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    # Ignore copy
+    def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: list[int]):
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # For VIP-LLaVA, the image features are computed this way:
+        # we select the features from index 1 onwards (dropping the class token) for layers -2, -5, -8, -11 and 6
+        image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
+        image_features = torch.cat(image_features, dim=-1)
+        image_features = self.multi_modal_projector(image_features)
+        return image_features
+
     def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         num_images, num_image_patches, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
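
Where LLaVA selects a single hidden layer, VIP-LLaVA concatenates the patch features of several layers along the hidden dimension before projecting. A shape sketch under assumed CLIP ViT-L dimensions (1 + 576 tokens, hidden size 1024, five selected layers):

    import torch

    # One hidden state per selected layer, class token dropped via [:, 1:]
    per_layer = [torch.randn(2, 577, 1024)[:, 1:] for _ in range(5)]  # 5 x (2, 576, 1024)
    concatenated = torch.cat(per_layer, dim=-1)                       # (2, 576, 5120)
    # The multi-modal projector then maps the 5 * 1024 = 5120 features per patch
    # down to the language model's hidden size.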
@@ -451,13 +462,9 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
         ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
         if pixel_values is not None:
-            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-            # For VIP-LLaVA, the image features are computed this way:
-            # we select the features from index 1 onwards for the layers -2, -5, -8, -11 and 6
-            image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
-            image_features = torch.cat(image_features, dim=-1)
-            image_features = self.multi_modal_projector(image_features)
+            image_features = self.get_image_features(
+                pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
+            )
 
             if legacy_processing:
                 logger.warning_once(
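
The VIP-LLaVA helper takes a list of layer indices rather than a single layer and a strategy. A matching usage sketch, again not part of this commit (it assumes the "llava-hf/vip-llava-7b-hf" checkpoint and the vision_feature_layers entry that VipLlavaConfig defines, [-2, -5, -8, -11, 6] by default):

    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, VipLlavaForConditionalGeneration

    model_id = "llava-hf/vip-llava-7b-hf"  # assumed checkpoint
    model = VipLlavaForConditionalGeneration.from_pretrained(model_id)
    image_processor = AutoImageProcessor.from_pretrained(model_id)

    image = Image.open("cat.png")  # placeholder: any local RGB image
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        image_features = model.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layers=model.config.vision_feature_layers,
        )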