mirror of https://github.com/huggingface/transformers.git · synced 2025-08-03 03:31:05 +06:00

VLM: add more modularity (#34175)

* update
* fix tests + fix copies
* fix tests once more

This commit is contained in: parent 21d5025826 · commit 5077bc034f
@@ -273,6 +273,20 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
    def get_image_features(
        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
        selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
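For context, the extracted method can be exercised on its own once a checkpoint is loaded. The snippet below is an illustrative sketch and not part of this diff; the checkpoint name, the dummy input, and the config attributes used for the two new arguments are assumptions.

# Illustrative sketch (not part of the diff); checkpoint name and dummy input are assumptions.
import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Normally produced by the processor; a random tensor keeps the sketch self-contained.
pixel_values = torch.randn(1, 3, 336, 336)

with torch.no_grad():
    image_features = model.get_image_features(
        pixel_values=pixel_values,
        vision_feature_layer=model.config.vision_feature_layer,
        vision_feature_select_strategy=model.config.vision_feature_select_strategy,
    )
print(image_features.shape)  # (num_images, image_length, embed_dim), e.g. (1, 576, 4096)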
@@ -705,6 +705,57 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: int,
        vision_feature_select_strategy: str,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_patches, image_length, embed_dim)`).
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
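The two accepted `pixel_values` layouts are easier to see outside the model. The sketch below is illustrative only and not part of the diff: `image_size_to_num_patches` is the existing helper from `modeling_llava_next.py`, while the grid pinpoints, image sizes, and tensor shapes are assumptions.

# Sketch of the 5-D (stacked) -> 4-D (flattened) conversion done at the top of get_image_features.
import torch
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]  # assumed config value
image_sizes = torch.tensor([[640, 480], [336, 336]])  # (H, W) of each original image

image_num_patches = [
    image_size_to_num_patches(image_size=imsize, grid_pinpoints=grid_pinpoints, patch_size=336)
    for imsize in image_sizes
]

# "Stacked" layout: every image padded to the same number of patches along dim 1.
pixel_values = torch.randn(2, max(image_num_patches), 3, 336, 336)

# Drop the padding patches per image and concatenate, exactly as the method does for 5-D inputs.
pixel_values = torch.cat(
    [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)], dim=0
)
print(image_num_patches, pixel_values.shape)  # per-image patch counts and the flattened 4-D batch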
@@ -796,34 +847,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
        ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

        if pixel_values is not None and pixel_values.size(0) > 0:
            # ! infer image_num_patches from image_sizes
            image_num_patches = [
                image_size_to_num_patches(
                    image_size=imsize,
                    grid_pinpoints=self.config.image_grid_pinpoints,
                    patch_size=self.config.vision_config.image_size,
                )
                for imsize in image_sizes
            ]
            # figure out if pixel_values is concatenated or stacked
            if pixel_values.dim() == 5:
                # stacking when input is (batch_size, num_patches, num_channels, height, width)
                _pixel_values_list = [
                    pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
                ]
                pixel_values = torch.cat(_pixel_values_list, dim=0)
            elif pixel_values.dim() != 4:
                # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
                raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

            image_features = self.vision_tower(pixel_values, output_hidden_states=True)
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            image_features = self.multi_modal_projector(selected_image_feature)
            image_features = torch.split(image_features, image_num_patches, dim=0)
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            image_features, feature_lens = self.pack_image_features(
@@ -744,6 +744,57 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: int,
        vision_feature_select_strategy: str,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_patches, image_length, embed_dim)`).
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
@@ -883,7 +934,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):

        image_features = feature_lens = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self._get_image_features(pixel_values, image_sizes)
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=self.vision_feature_layer,
                vision_feature_select_strategy=self.vision_feature_select_strategy,
            )
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
@@ -893,7 +949,11 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):

        video_features = video_feature_lens = None
        if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
            video_features = self._get_video_features(pixel_values_videos)
            video_features = self.get_video_features(
                pixel_values_videos,
                vision_feature_layer=self.vision_feature_layer,
                vision_feature_select_strategy=self.vision_feature_select_strategy,
            )
            video_features = [feature.flatten(0, 1) for feature in video_features]
            video_feature_lens = [feature.size(0) for feature in video_features]
            video_features = torch.cat(video_features, dim=0)
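The post-processing of the list returned by `get_video_features` is easy to follow with toy tensors. A sketch with assumed shapes (2 videos, 8 frames, 144 pooled tokens per frame, hidden size 32):

# Toy illustration of the flatten/concat step applied to the get_video_features output; shapes are assumptions.
import torch

video_features = [torch.randn(8, 144, 32) for _ in range(2)]  # one tensor per video

video_features = [feature.flatten(0, 1) for feature in video_features]  # (8 * 144, 32) each
video_feature_lens = [feature.size(0) for feature in video_features]    # [1152, 1152]
video_features = torch.cat(video_features, dim=0)                       # (2304, 32)
print(video_feature_lens, video_features.shape)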
@@ -1080,46 +1140,35 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):

        return model_inputs

    def _get_image_features(self, pixel_values, image_sizes):
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
    def get_video_features(
        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ):
        """
        Obtains video last hidden states from the vision tower and apply multimodal projection.

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    def _get_video_features(self, pixel_values):
        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
                The tensors corresponding to the input video.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_videos, video_length, embed_dim)`).
        """
        batch_size, frames, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        video_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_video_features = video_features.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_video_features = selected_video_features[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_video_features = selected_video_features

        # Same as image features except that video has pooling layer
        image_features = self.vision_resampler(selected_image_feature)
        image_features = self.multi_modal_projector(image_features)
        image_features = torch.split(image_features, frames, dim=0)
        return image_features
        video_features = self.vision_resampler(selected_video_features)
        video_features = self.multi_modal_projector(video_features)
        video_features = torch.split(video_features, frames, dim=0)
        return video_features
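As with the image path, the new video entry point can be exercised directly. The snippet is a sketch, not part of the diff; the checkpoint name and the dummy input shape are assumptions.

# Illustrative sketch; checkpoint name and dummy input shape are assumptions.
import torch
from transformers import LlavaNextVideoForConditionalGeneration

model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

# (batch_size, num_frames, channels, height, width)
pixel_values_videos = torch.randn(1, 8, 3, 336, 336)

with torch.no_grad():
    video_features = model.get_video_features(
        pixel_values_videos,
        vision_feature_layer=model.config.vision_feature_layer,
        vision_feature_select_strategy=model.config.vision_feature_select_strategy,
    )
# One tensor per video, shaped (num_frames, tokens_per_frame, embed_dim) after the pooling resampler.
print(len(video_features), video_features[0].shape)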
@@ -225,7 +225,30 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
        self.vision_resampler = LlavaNextVideoPooler(config)
        self.post_init()

    def _get_image_features(self, pixel_values, image_sizes):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: int,
        vision_feature_select_strategy: str,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_patches, image_length, embed_dim)`).
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
@@ -244,30 +267,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy == "default":
        selected_image_feature = image_features.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    def _get_video_features(self, pixel_values):
    def get_video_features(
        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ):
        """
        Obtains video last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
                The tensors corresponding to the input video.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_videos, video_length, embed_dim)`).
        """
        batch_size, frames, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        video_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_video_features = video_features.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_video_features = selected_video_features[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_video_features = selected_video_features

        # Same as image features except that video has pooling layer
        image_features = self.vision_resampler(selected_image_feature)
        image_features = self.multi_modal_projector(image_features)
        image_features = torch.split(image_features, frames, dim=0)
        return image_features
        video_features = self.vision_resampler(selected_video_features)
        video_features = self.multi_modal_projector(video_features)
        video_features = torch.split(video_features, frames, dim=0)
        return video_features

    @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig")
    def forward(
@@ -407,7 +447,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):

        image_features = feature_lens = None
        if pixel_values is not None and pixel_values.size(0) > 0:
            image_features = self._get_image_features(pixel_values, image_sizes)
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=self.vision_feature_layer,
                vision_feature_select_strategy=self.vision_feature_select_strategy,
            )
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
@@ -417,7 +462,11 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):

        video_features = video_feature_lens = None
        if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
            video_features = self._get_video_features(pixel_values_videos)
            video_features = self.get_video_features(
                pixel_values_videos,
                vision_feature_layer=self.vision_feature_layer,
                vision_feature_select_strategy=self.vision_feature_select_strategy,
            )
            video_features = [feature.flatten(0, 1) for feature in video_features]
            video_feature_lens = [feature.size(0) for feature in video_features]
            video_features = torch.cat(video_features, dim=0)
@@ -481,6 +481,91 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
        image_features = image_features.view(batch_frames, -1, dim)
        return image_features

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: int,
        vision_feature_select_strategy: str,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual image size of each images (H, W).
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_patches, image_length, embed_dim)`).
        """
        # ! infer image_num_patches from image_sizes
        image_num_patches = [
            image_size_to_num_patches(
                image_size=imsize,
                grid_pinpoints=self.config.image_grid_pinpoints,
                patch_size=self.config.vision_config.image_size,
            )
            for imsize in image_sizes
        ]
        if pixel_values.dim() == 5:
            # stacked if input is (batch_size, num_patches, num_channels, height, width)
            _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
            pixel_values = torch.cat(_pixel_values_list, dim=0)
        elif pixel_values.dim() != 4:
            # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
            raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

        image_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_features.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = torch.split(image_features, image_num_patches, dim=0)
        return image_features

    def get_video_features(
        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ):
        """
        Obtains video last hidden states from the vision tower, apply multimodal projection and pooling.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
                The tensors corresponding to the input video.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches
            and are of shape `(num_videos, video_length, embed_dim)`).
        """
        batch_size, frames, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.view(batch_size * frames, channels, height, width)
        video_features = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_video_feature = video_features.hidden_states[vision_feature_layer]

        if vision_feature_select_strategy == "default":
            selected_video_feature = selected_video_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_video_feature = selected_video_feature
        video_features = self.multi_modal_projector(selected_video_feature)

        video_features = self.apply_pooling(video_features)
        video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)

        return video_features

    @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING)
    def forward(
        self,
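The last two lines of `get_video_features` merge the pooled per-frame tokens back into one sequence per video. A toy sketch of that reshape (the token count and hidden size are assumptions, and the pooling step itself is skipped here):

# Toy illustration of the final reshape in LlavaOnevision get_video_features; shapes are assumptions.
import torch

batch_size, frames = 2, 8
# After projection and apply_pooling the features are (batch_size * frames, pooled_tokens, dim);
# assume 196 pooled tokens per frame and a hidden size of 64 for the toy example.
video_features = torch.randn(batch_size * frames, 196, 64)

video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
print(video_features.shape)  # (2, 8 * 196, 64): one flat token sequence per video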
@@ -580,35 +665,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):

        # Images are processed with Anyres
        if pixel_values is not None:
            image_num_patches = [
                image_size_to_num_patches(
                    image_size=imsize,
                    grid_pinpoints=self.config.image_grid_pinpoints,
                    patch_size=self.config.vision_config.image_size,
                )
                for imsize in image_sizes
            ]

            # unpad extra patches and concatenate them
            if pixel_values.dim() == 5:
                _pixel_values_list = [
                    pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
                ]
                # [batch_size*frames*num_patches, num_channels, height, width] where frames=1 for images
                pixel_values = torch.cat(_pixel_values_list, dim=0)
            elif pixel_values.dim() != 4:
                raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

            image_features = self.vision_tower(pixel_values, output_hidden_states=True)
            selected_image_feature = image_features.hidden_states[vision_feature_layer]

            if vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            image_features = self.multi_modal_projector(selected_image_feature)

            image_features = torch.split(image_features, image_num_patches, dim=0)
            image_features = self.get_image_features(
                pixel_values,
                image_sizes,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
@@ -632,20 +694,14 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):

        # Video are simply embedded and further pooled to decrease seq len
        if pixel_values_videos is not None:
            batch_size, frames, channels, height, width = pixel_values_videos.shape
            pixel_values_videos = pixel_values_videos.view(batch_size * frames, channels, height, width)
            video_features = self.vision_tower(pixel_values_videos, output_hidden_states=True)
            selected_video_feature = video_features.hidden_states[vision_feature_layer]

            if vision_feature_select_strategy == "default":
                selected_video_feature = selected_video_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_video_feature = selected_video_feature
            video_features = self.multi_modal_projector(selected_video_feature)

            video_features = self.apply_pooling(video_features)
            video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1)
            image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
            video_features = self.get_video_features(
                pixel_values_videos,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )
            image_newline = (
                self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device)
            )
            video_features = torch.cat((video_features, image_newline), dim=1)
            video_features = video_features.flatten(0, 1)
            n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
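After the call, the learned `image_newline` separator is appended once per video and everything is flattened into a single token stream. A toy version with assumed shapes:

# Toy illustration of appending the newline token per video; shapes are assumptions and the random
# vector stands in for the model's learned self.image_newline parameter.
import torch

num_videos, tokens_per_video, dim = 2, 1568, 64
video_features = torch.randn(num_videos, tokens_per_video, dim)
image_newline = torch.randn(dim)

image_newline = image_newline[None, None, :].repeat(video_features.shape[0], 1, 1)  # (2, 1, 64)
video_features = torch.cat((video_features, image_newline), dim=1)                  # (2, 1569, 64)
video_features = video_features.flatten(0, 1)                                       # (2 * 1569, 64)
print(video_features.shape)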
@@ -392,6 +392,22 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
        )
        return causal_mask

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = image_features / (self.config.hidden_size**0.5)
        return image_features

    @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
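Unlike the Llava-style models above, PaliGemma takes no layer or strategy argument and additionally normalizes the projected features by dividing by `hidden_size ** 0.5`. A usage sketch; the checkpoint name is an assumption and the weights are gated on the Hub:

# Illustrative sketch; the checkpoint name is an assumption (gated repo, authentication may be required).
import torch
from transformers import PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")

pixel_values = torch.randn(1, 3, 224, 224)  # (batch_size, channels, height, width)

with torch.no_grad():
    image_features = model.get_image_features(pixel_values)
print(image_features.shape)  # (num_images, image_length, embed_dim), e.g. (1, 256, 2048)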
@@ -477,10 +493,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):

        # Merge text and images
        if pixel_values is not None:
            image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
            selected_image_feature = image_outputs.last_hidden_state
            image_features = self.multi_modal_projector(selected_image_feature)
            image_features = image_features / (self.config.hidden_size**0.5)
            image_features = self.get_image_features(pixel_values)

            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
@@ -23,7 +23,7 @@ from torch import nn

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
@@ -355,41 +355,59 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):

        return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids

    def _get_vision_features(
        self,
        pixel_values_images: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        if pixel_values_images is None and pixel_values_videos is None:
            raise ValueError("You have to specify `pixel_values_images` or `pixel_values_videos`")
    def get_image_features(
        self, pixel_values_images: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        # videos do not need to select features and it's always "full" (as it is done in the orig implementation)
        if pixel_values_videos is not None:
            batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
        Args:
            pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """

            pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
            video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
            video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
        image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
        image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1)

        if vision_feature_select_strategy == "default":
            image_outputs = image_outputs[:, 1:]
        elif vision_feature_select_strategy == "full":
            image_outputs = image_outputs
        else:
            video_outputs = None
            num_frames = 0
            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

        if pixel_values_images is not None:
            image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
            image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1)
        image_features = self.multi_modal_projector(image_outputs)

            if vision_feature_select_strategy == "default":
                image_outputs = image_outputs[:, 1:]
            elif vision_feature_select_strategy == "full":
                image_outputs = image_outputs
            else:
                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
        else:
            image_outputs = None
        return image_features

        return image_outputs, video_outputs, num_frames
    def get_video_features(self, pixel_values_videos: torch.FloatTensor, vision_feature_layer: int):
        """
        Obtains video last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`)
                The tensors corresponding to the input videos.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
        Returns:
            video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`).
            frames (`int`): Number of frames the videos have.
        """
        batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape

        pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
        video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
        video_features = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
        video_features = self.multi_modal_projector(video_features)

        return video_features, num_frames

    @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
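Video-LLaVA ends up with one entry point per tower; note that `get_video_features` also returns the frame count, which the merging code needs later. A usage sketch (checkpoint name and dummy shapes are assumptions):

# Illustrative sketch; checkpoint name and dummy input shapes are assumptions.
import torch
from transformers import VideoLlavaForConditionalGeneration

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

pixel_values_images = torch.randn(1, 3, 224, 224)
pixel_values_videos = torch.randn(1, 8, 3, 224, 224)  # 8 frames per video

with torch.no_grad():
    image_features = model.get_image_features(
        pixel_values_images,
        vision_feature_layer=model.config.vision_feature_layer,
        vision_feature_select_strategy=model.config.vision_feature_select_strategy,
    )
    video_features, num_frames = model.get_video_features(
        pixel_values_videos, vision_feature_layer=model.config.vision_feature_layer
    )
print(image_features.shape, video_features.shape, num_frames)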
@@ -534,110 +552,106 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
        )
        legacy_processing = inputs_not_expanded or pixels_present

        if pixel_values_images is not None or pixel_values_videos is not None:
            image_outputs, video_outputs, num_frames = self._get_vision_features(
                pixel_values_images=pixel_values_images,
                pixel_values_videos=pixel_values_videos,
        image_features = None
        if pixel_values_images is not None:
            image_features = self.get_image_features(
                pixel_values_images,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            image_features = video_features = None
            if image_outputs is not None:
                image_features = self.multi_modal_projector(image_outputs)
            if video_outputs is not None:
                video_features = self.multi_modal_projector(video_outputs)
        video_features = None
        if pixel_values_videos is not None:
            video_features, num_frames = self.get_video_features(
                pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
            )

            if legacy_processing:
                logger.warning_once(
                    "Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
                )
                if input_ids.shape[1] != 1:
                    for features, frames in ((image_features, 1), (video_features, num_frames)):
                        if features is not None:
                            (
                                inputs_embeds,
                                attention_mask,
                                labels,
                                position_ids,
                                input_ids,
                            ) = self._merge_input_ids_with_visual_features(
                                features,
                                inputs_embeds,
                                input_ids,
                                attention_mask,
                                labels,
                                num_frames=frames,
                            )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # if one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

        # TODO: @raushan retain only the new behavior after v4.47
        if legacy_processing:
            logger.warning_once(
                "Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            if input_ids.shape[1] != 1:
                for features, frames in ((image_features, 1), (video_features, num_frames)):
                    if features is not None:
                        (
                            inputs_embeds,
                            attention_mask,
                            labels,
                            position_ids,
                            input_ids,
                        ) = self._merge_input_ids_with_visual_features(
                            features,
                            inputs_embeds,
                            input_ids,
                            attention_mask,
                            labels,
                            num_frames=frames,
                        )
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            else:
                if image_outputs is not None:
                    n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                    n_image_features = image_features.shape[1]
                    if n_image_tokens != n_image_features:
                        raise ValueError(
                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                        )
                    special_image_mask = (
                        (input_ids == self.config.image_token_index)
                        .unsqueeze(-1)
                        .expand_as(inputs_embeds)
                        .to(inputs_embeds.device)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]
                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

        # TODO: @raushan retain only the new behavior after v4.47
        else:
            if pixel_values_images is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                n_image_features = image_features.shape[1]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                if video_outputs is not None:
                    n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
                    n_video_features = video_features.shape[1]
                    if n_video_tokens != n_video_features:
                        raise ValueError(
                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                        )
                    special_image_mask = (
                        (input_ids == self.config.video_token_index)
                        .unsqueeze(-1)
                        .expand_as(inputs_embeds)
                        .to(inputs_embeds.device)
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

            if pixel_values_videos is not None:
                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
                n_video_features = video_features.shape[1]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
                special_image_mask = (
                    (input_ids == self.config.video_token_index)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
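In the non-legacy branch the projected features are written into the placeholder token positions with `masked_scatter`. A toy version of that mechanism, with made-up token ids and shapes:

# Toy illustration of the masked_scatter merge; token ids and shapes are made up.
import torch

image_token_index = 99
input_ids = torch.tensor([[1, 99, 99, 99, 5, 6]])  # three image placeholder tokens
inputs_embeds = torch.zeros(1, 6, 4)               # (batch_size, seq_len, hidden_size)
image_features = torch.arange(3 * 4, dtype=torch.float).reshape(1, 3, 4)

special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds)  # positions 1-3 now hold the three feature rows, the rest stay zero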
@@ -275,6 +275,17 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):

    # Ignore copy
    def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
            vision_feature_layers (`List[int]`):
                The list og indexes of the layers to select the vision feature.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

        # For VIP-llava, the image features are computed this way
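The hunk is truncated here, but the signature already shows the VipLlava-specific detail: a list of layers (`vision_feature_layers`) instead of a single index. For context only, a toy sketch of selecting several hidden-state layers and concatenating them along the feature dimension, which is the general idea behind that argument; this is an assumption for illustration, not a copy of the elided body:

# Toy sketch (assumption, not the elided implementation): concatenate several hidden-state layers
# along the feature dimension before the multimodal projection.
import torch

num_layers, batch_size, seq_len, dim = 12, 1, 577, 64
hidden_states = [torch.randn(batch_size, seq_len, dim) for _ in range(num_layers)]
vision_feature_layers = [-2, -5, -8, -11]

selected = [hidden_states[idx][:, 1:] for idx in vision_feature_layers]  # drop the CLS token
image_features = torch.cat(selected, dim=-1)  # (1, 576, 64 * 4)
print(image_features.shape)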