VLM: add more modularity (#34175)

* update

* fix tests + fix copies

* fix tests once more
Raushan Turganbay 2024-10-22 07:56:35 +02:00 committed by GitHub
parent 21d5025826
commit 5077bc034f
8 changed files with 493 additions and 258 deletions


@@ -273,6 +273,20 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def get_image_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# this is not memory efficient at all; (output_hidden_states=True) will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
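As a quick illustration of what the new `get_image_features` does with `vision_feature_select_strategy`, here is a minimal self-contained sketch; the token counts, hidden sizes and the plain `nn.Linear` projector are placeholders, not the actual LLaVA modules:

```python
import torch
from torch import nn

# Dummy stand-ins: 2 images, 577 vision tokens each (1 CLS + 576 patches), vision hidden size 1024,
# text embedding size 4096. All numbers are illustrative only.
hidden_states_per_layer = [torch.randn(2, 577, 1024) for _ in range(25)]
multi_modal_projector = nn.Linear(1024, 4096)

vision_feature_layer = -2
vision_feature_select_strategy = "default"

selected_image_feature = hidden_states_per_layer[vision_feature_layer]
if vision_feature_select_strategy == "default":
    # "default" drops the CLS token and keeps only the patch tokens
    selected_image_feature = selected_image_feature[:, 1:]
image_features = multi_modal_projector(selected_image_feature)
print(image_features.shape)  # torch.Size([2, 576, 4096])
```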


@@ -705,6 +705,57 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -796,34 +847,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
if pixel_values is not None and pixel_values.size(0) > 0:
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
# figure out if pixel_values is concatenated or stacked
if pixel_values.dim() == 5:
# stacking when input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
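The per-image bookkeeping in the new LLaVA-NeXT `get_image_features` ultimately comes down to `torch.split` with the patch counts derived from `image_size_to_num_patches`; a toy sketch of that last step (the counts and sizes below are invented):

```python
import torch

# Suppose two images were tiled into 5 and 3 patches; after the vision tower and projector
# all patches sit stacked along dim 0.
image_num_patches = [5, 3]
image_length, embed_dim = 576, 4096
image_features = torch.randn(sum(image_num_patches), image_length, embed_dim)

# get_image_features returns one tensor per image, exactly what torch.split produces here
per_image_features = torch.split(image_features, image_num_patches, dim=0)
print([t.shape for t in per_image_features])
# [torch.Size([5, 576, 4096]), torch.Size([3, 576, 4096])]
```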


@@ -744,6 +744,57 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -883,7 +934,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self._get_image_features(pixel_values, image_sizes)
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
@@ -893,7 +949,11 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
video_features = video_feature_lens = None
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self._get_video_features(pixel_values_videos)
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
@@ -1080,46 +1140,35 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
return model_inputs
def _get_image_features(self, pixel_values, image_sizes):
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
def get_video_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def _get_video_features(self, pixel_values):
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input video.
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_video_features = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_features = selected_video_features
# Same as image features except that video has pooling layer
image_features = self.vision_resampler(selected_image_feature)
image_features = self.multi_modal_projector(image_features)
image_features = torch.split(image_features, frames, dim=0)
return image_features
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
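For videos, the new `get_video_features` folds the frame dimension into the batch dimension before the vision tower and splits the result back per video afterwards; a shape-only sketch with placeholder sizes (the real method also runs the resampler and projector in between):

```python
import torch

# Placeholder video batch: 2 videos, 8 frames each, 3x336x336 pixels.
pixel_values = torch.randn(2, 8, 3, 336, 336)
batch_size, frames, channels, height, width = pixel_values.shape

# fold frames into the batch dimension, as done before calling the vision tower
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
print(pixel_values.shape)  # torch.Size([16, 3, 336, 336])

# pretend these are the projected per-frame features (144 tokens per frame is made up)
video_features = torch.randn(batch_size * frames, 144, 4096)
per_video_features = torch.split(video_features, frames, dim=0)
print([t.shape for t in per_video_features])  # two tensors of shape (8, 144, 4096)
```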


@@ -225,7 +225,30 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
self.vision_resampler = LlavaNextVideoPooler(config)
self.post_init()
def _get_image_features(self, pixel_values, image_sizes):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
@@ -244,30 +267,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy == "default":
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def _get_video_features(self, pixel_values):
def get_video_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input video.
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_video_features = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_features = selected_video_features
# Same as image features except that video has pooling layer
image_features = self.vision_resampler(selected_image_feature)
image_features = self.multi_modal_projector(image_features)
image_features = torch.split(image_features, frames, dim=0)
return image_features
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig")
def forward(
@@ -407,7 +447,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self._get_image_features(pixel_values, image_sizes)
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
@@ -417,7 +462,11 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
video_features = video_feature_lens = None
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self._get_video_features(pixel_values_videos)
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
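Because the helper is now public, video features can also be extracted directly from a loaded model; a hedged usage sketch (the checkpoint name, resolution and all-zero input are examples only, not part of this diff):

```python
import torch
from transformers import LlavaNextVideoForConditionalGeneration

# Example checkpoint; any LLaVA-NeXT-Video checkpoint with this architecture should work.
model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

# Placeholder input: 1 video, 8 frames, 3x336x336 (use a processor for real inputs).
pixel_values_videos = torch.zeros(1, 8, 3, 336, 336)
with torch.no_grad():
    video_features = model.get_video_features(
        pixel_values_videos,
        vision_feature_layer=model.config.vision_feature_layer,
        vision_feature_select_strategy=model.config.vision_feature_select_strategy,
    )
print([f.shape for f in video_features])  # one feature tensor per video
```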


@@ -481,6 +481,91 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
image_features = image_features.view(batch_frames, -1, dim)
return image_features
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
def get_video_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
"""
Obtains video last hidden states from the vision tower, applies multimodal projection and pooling.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input video.
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all patches
and of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_video_feature = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_feature = selected_video_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_feature = selected_video_feature
video_features = self.multi_modal_projector(selected_video_feature)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
return video_features
@add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
def forward(
self,
@@ -580,35 +665,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
# Images are processed with Anyres
if pixel_values is not None:
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
# unpad extra patches and concatenate them
if pixel_values.dim() == 5:
_pixel_values_list = [
pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
# [batch_size*frames*num_patches, num_channels, height, width] where frames=1 for images
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
@@ -632,20 +694,14 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
# Video are simply embedded and further pooled to decrease seq len
if pixel_values_videos is not None:
batch_size, frames, channels, height, width = pixel_values_videos.shape
pixel_values_videos = pixel_values_videos.view(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values_videos, output_hidden_states=True)
selected_video_feature = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_feature = selected_video_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_feature = selected_video_feature
video_features = self.multi_modal_projector(selected_video_feature)
video_features = self.apply_pooling(video_features)
video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
image_newline = (
self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device)
)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
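For LLaVA-OneVision videos, `get_video_features` now returns the pooled per-frame tokens, and the forward pass only appends one `image_newline` embedding per video before flattening; a toy-sized sketch of that remaining step (all sizes are invented):

```python
import torch

# Pretend pooled video features: 2 videos, 8 frames * 169 pooled tokens each, hidden size 3584.
video_features = torch.randn(2, 8 * 169, 3584)
image_newline = torch.randn(3584)

# append one learned "newline" embedding per video, then flatten into a single token sequence
newline = image_newline[None, None, :].repeat(video_features.shape[0], 1, 1)
video_features = torch.cat((video_features, newline), dim=1)  # (2, 8*169 + 1, 3584)
video_features = video_features.flatten(0, 1)                 # (2 * (8*169 + 1), 3584)
print(video_features.shape)
```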


@@ -392,6 +392,22 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
)
return causal_mask
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
image_features = image_features / (self.config.hidden_size**0.5)
return image_features
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -477,10 +493,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
# Merge text and images
if pixel_values is not None:
image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
image_features = image_features / (self.config.hidden_size**0.5)
image_features = self.get_image_features(pixel_values)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
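The factored-out PaliGemma `get_image_features` only projects the SigLIP last hidden state and rescales it by `1/sqrt(hidden_size)`; a stand-alone sketch with placeholder dimensions and a plain `nn.Linear` in place of the real projector:

```python
import torch
from torch import nn

# Placeholder sizes: 256 vision tokens, vision hidden size 1152, text hidden size 2048.
hidden_size = 2048
last_hidden_state = torch.randn(1, 256, 1152)
multi_modal_projector = nn.Linear(1152, hidden_size)

# project into the text embedding space, then scale before merging with the text embeddings
image_features = multi_modal_projector(last_hidden_state)
image_features = image_features / (hidden_size**0.5)
print(image_features.shape)  # torch.Size([1, 256, 2048])
```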


@@ -23,7 +23,7 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
@@ -355,41 +355,59 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
def _get_vision_features(
self,
pixel_values_images: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
if pixel_values_images is None and pixel_values_videos is None:
raise ValueError("You have to specify `pixel_values_images` or `pixel_values_videos`")
def get_image_features(
self, pixel_values_images: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
# videos do not need to select features and it's always "full" (as it is done in the orig implementation)
if pixel_values_videos is not None:
batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
Args:
pixel_values_images (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1)
if vision_feature_select_strategy == "default":
image_outputs = image_outputs[:, 1:]
elif vision_feature_select_strategy == "full":
image_outputs = image_outputs
else:
video_outputs = None
num_frames = 0
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
if pixel_values_images is not None:
image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1)
image_features = self.multi_modal_projector(image_outputs)
if vision_feature_select_strategy == "default":
image_outputs = image_outputs[:, 1:]
elif vision_feature_select_strategy == "full":
image_outputs = image_outputs
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
else:
image_outputs = None
return image_features
return image_outputs, video_outputs, num_frames
def get_video_features(self, pixel_values_videos: torch.FloatTensor, vision_feature_layer: int):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`)
The tensors corresponding to the input videos.
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
Returns:
video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`.
frames (`int`): Number of frames the videos have.
"""
batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
video_features = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
video_features = self.multi_modal_projector(video_features)
return video_features, num_frames
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -534,110 +552,106 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
)
legacy_processing = inputs_not_expanded or pixels_present
if pixel_values_images is not None or pixel_values_videos is not None:
image_outputs, video_outputs, num_frames = self._get_vision_features(
pixel_values_images=pixel_values_images,
pixel_values_videos=pixel_values_videos,
image_features = None
if pixel_values_images is not None:
image_features = self.get_image_features(
pixel_values_images,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
image_features = video_features = None
if image_outputs is not None:
image_features = self.multi_modal_projector(image_outputs)
if video_outputs is not None:
video_features = self.multi_modal_projector(video_outputs)
video_features = None
if pixel_values_videos is not None:
video_features, num_frames = self.get_video_features(
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
)
if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
if input_ids.shape[1] != 1:
for features, frames in ((image_features, 1), (video_features, num_frames)):
if features is not None:
(
inputs_embeds,
attention_mask,
labels,
position_ids,
input_ids,
) = self._merge_input_ids_with_visual_features(
features,
inputs_embeds,
input_ids,
attention_mask,
labels,
num_frames=frames,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
# TODO: @raushan retain only the new behavior after v4.47
if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
if input_ids.shape[1] != 1:
for features, frames in ((image_features, 1), (video_features, num_frames)):
if features is not None:
(
inputs_embeds,
attention_mask,
labels,
position_ids,
input_ids,
) = self._merge_input_ids_with_visual_features(
features,
inputs_embeds,
input_ids,
attention_mask,
labels,
num_frames=frames,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
if image_outputs is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
# TODO: @raushan retain only the new behavior after v4.47
else:
if pixel_values_images is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_outputs is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
n_video_features = video_features.shape[1]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
n_video_features = video_features.shape[1]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
special_image_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
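In the non-legacy branch above, the extracted features are written into `inputs_embeds` at the positions of the special image/video tokens with `masked_scatter`; a toy example of that mechanism (the token id and tensor sizes are made up):

```python
import torch

image_token_index = 32000  # placeholder id for the special image token
input_ids = torch.tensor([[1, 32000, 32000, 5, 6]])
inputs_embeds = torch.zeros(1, 5, 8)
image_features = torch.arange(2 * 8, dtype=torch.float).reshape(2, 8)  # one row per image token

# build a mask over the embedding positions that correspond to image tokens and scatter into it
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # the first image feature row now sits at the first image token slot
```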


@@ -275,6 +275,17 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
# Ignore copy
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layers (`List[int]`):
The list of indexes of the layers to select the vision feature.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# For VIP-llava, the image features are computed this way
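VipLlava's variant takes a list of layers instead of a single one; the selection step gathers those hidden states and concatenates them along the feature dimension before projection. The snippet below is a dummy-tensor sketch of that idea, not the code from this diff, and the layer indices and sizes are illustrative rather than the checkpoint defaults:

```python
import torch

vision_feature_layers = [-2, -5, -8]
hidden_states = [torch.randn(1, 577, 1024) for _ in range(25)]  # fake per-layer tower outputs

# take the requested layers, drop the CLS token from each, and concatenate along the feature dim
selected = [hidden_states[idx][:, 1:] for idx in vision_feature_layers]
image_features = torch.cat(selected, dim=-1)
print(image_features.shape)  # torch.Size([1, 576, 3072])
```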