VLM: fixes after refactor (#32907)

* leave only half of the changes

* fix tests

* [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava

* fix tests, first try

* [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava

* fix, second try

* [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava

* fix

* [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava
Raushan Turganbay 2024-09-10 12:02:37 +02:00 committed by GitHub
parent f24f084329
commit 7d2d6ce9cb
15 changed files with 577 additions and 500 deletions

View File

@@ -476,6 +476,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
@@ -506,6 +507,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
# TODO: @raushan retain only the new behavior after v4.47
else:
@@ -585,9 +589,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
**kwargs,
)
if legacy_processing:
model_inputs["pixel_values"] = pixel_values
elif cache_position[0] == 0:
if legacy_processing or cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values

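To make the new `cache_position` handling concrete, here is a small standalone sketch (not part of the model class) of the two values computed in the legacy path above; the mask sizes are toy numbers standing in for the sequence after image features are merged.

```python
import torch

# Prefill with the legacy merge: every position of the merged sequence is new.
attention_mask = torch.ones(1, 10, dtype=torch.long)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
print(cache_position)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Cached decoding step: only the trailing `target_length` ids (here 1 new token) are processed,
# so cache_position is sliced to the last positions of the extended mask.
target_length = 1
attention_mask = torch.ones(1, 11, dtype=torch.long)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
print(cache_position)  # tensor([10])
```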
View File

@@ -136,6 +136,7 @@ class LlavaProcessor(ProcessorMixin):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# try to expand inputs in processing if we have the necessary parts
prompt_strings = text
if image_inputs.get("pixel_values") is not None:
if self.patch_size is not None and self.vision_feature_select_strategy is not None:
# Replace the image token with the expanded image token sequence
@@ -150,7 +151,6 @@ class LlavaProcessor(ProcessorMixin):
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
else:
prompt_strings = text
logger.warning_once(
"Expanding inputs for image tokens in LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "

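For reference, a rough sketch of the per-image token count the processor uses when it does expand the image token; the helper below is illustrative only (it is not `LlavaProcessor` API), and the 336 px / 14 px patch values are assumed from the usual llava-1.5 CLIP backbone.

```python
def expanded_image_tokens(height: int, width: int, patch_size: int, strategy: str) -> int:
    # one token per vision patch, plus the CLS token
    tokens = (height // patch_size) * (width // patch_size) + 1
    if strategy == "default":
        tokens -= 1  # the "default" strategy drops the CLS feature
    return tokens

print(expanded_image_tokens(336, 336, 14, "default"))  # 576 copies of the image token
```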
View File

@@ -848,6 +848,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
position_ids,
labels=labels,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
@@ -877,6 +878,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
# TODO: @raushan retain only the new behavior after v4.47
else:
@@ -956,12 +960,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
**kwargs,
)
if legacy_processing:
model_inputs["pixel_values"] = pixel_values
model_inputs["image_sizes"] = image_sizes
elif cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if legacy_processing or cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["image_sizes"] = image_sizes

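The merged condition above can be read as a single gate; a toy restatement (hypothetical helper, not the model's method) under the assumption that `cache_position[0] == 0` identifies the prefill step:

```python
import torch

def forward_vision_inputs(legacy_processing: bool, cache_position: torch.Tensor) -> bool:
    # Pixel values (and image_sizes) are only needed while input_ids still contain image tokens:
    # on the prefill step, or whenever the legacy merging path is active.
    return legacy_processing or int(cache_position[0]) == 0

print(forward_vision_inputs(False, torch.arange(12)))    # True  -> prefill
print(forward_vision_inputs(False, torch.tensor([12])))  # False -> cached decode step
print(forward_vision_inputs(True, torch.tensor([12])))   # True  -> legacy processing
```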
View File

@@ -140,30 +140,29 @@ class LlavaNextProcessor(ProcessorMixin):
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if self.patch_size is None or self.vision_feature_select_strategy is None:
prompt_strings = text
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
# cannot infer image expansion length if no images are found
elif not image_inputs:
prompt_strings = text
else:
image_sizes = image_inputs["image_sizes"]
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
prompt_strings = []
for image_size, sample in zip(image_sizes, text):
# Replace the image token with the expanded image token sequence
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
prompt_strings = text
if image_inputs:
if self.patch_size is None or self.vision_feature_select_strategy is None:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
else:
image_sizes = iter(image_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
prompt_strings = []
for sample in text:
while self.image_token in sample:
image_size = next(image_sizes)
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
text_inputs = self.tokenizer(
prompt_strings,

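A minimal standalone sketch of the `<placeholder>` trick above, assuming made-up per-image token counts (the real processor derives them from `image_sizes` via `_get_number_of_features`); replacing one occurrence at a time keeps already-expanded tokens from being matched again when a prompt contains several images.

```python
image_token = "<image>"
text = ["USER: <image><image>\nDescribe the similarity between the two images:\nASSISTANT:"]
token_counts = iter([5, 3])  # assumed counts for two images of different resolutions

prompt_strings = []
for sample in text:
    while image_token in sample:
        n = next(token_counts)
        sample = sample.replace(image_token, "<placeholder>" * n, 1)  # expand one image at a time
    prompt_strings.append(sample)
prompt_strings = [s.replace("<placeholder>", image_token) for s in prompt_strings]

print(prompt_strings[0].count(image_token))  # 8 == 5 + 3
```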
View File

@@ -29,7 +29,6 @@ from transformers.models.llava_next.modeling_llava_next import (
image_size_to_num_patches,
)
from ...cache_utils import Cache
from ...utils import (
logging,
replace_return_docstrings,
@@ -389,13 +388,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
inputs_expanded = (
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
legacy_processing = inputs_expanded or pixels_present
pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None)
legacy_processing = inputs_not_expanded or pixels_present
image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
@@ -414,75 +417,76 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
if legacy_processing:
logger.warning_once(
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
if legacy_processing:
logger.warning_once(
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
if input_ids.shape[1] != 1:
iterator = (
(image_features, feature_lens, self.config.image_token_index),
(video_features, video_feature_lens, self.config.video_token_index),
)
if input_ids.shape[1] != 1:
iterator = (
(image_features, feature_lens, self.config.image_token_index),
(video_features, video_feature_lens, self.config.video_token_index),
)
for features, lens, special_token in zip(iterator):
if features is not None:
(
inputs_embeds,
attention_mask,
position_ids,
labels,
input_ids,
) = self._merge_input_ids_with_image_features(
features,
lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
image_token_index=special_token,
)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
# TODO: @raushan retain only the new behavior after v4.47
for features, lens, special_token in iterator:
if features is not None:
(
inputs_embeds,
attention_mask,
position_ids,
labels,
input_ids,
) = self._merge_input_ids_with_image_features(
features,
lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
image_token_index=special_token,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
if image_features is not None:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
if video_features is not None:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
# TODO: @raushan retain only the new behavior after v4.47
else:
if image_features is not None:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
@@ -493,6 +497,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs[0]
@@ -534,58 +539,34 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
pixel_values_videos=None,
image_sizes=None,
attention_mask=None,
cache_position=None,
**kwargs,
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"image_sizes": image_sizes,
}
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
**kwargs,
)
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if legacy_processing or cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes"] = image_sizes
return model_inputs

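A toy restatement of the per-modality `legacy_processing` check above (standalone function, not the model's code; token ids and sequence lengths are assumed example values). The point of the change is that the flag now trips when either modality's tokens are un-expanded and that modality's pixels were passed, instead of requiring both modalities at once.

```python
import torch

def is_legacy(input_ids, pixel_values, pixel_values_videos,
              image_token_id=32000, video_token_id=32001,
              image_seq_length=576, video_seq_length=288):
    img_token_not_enough = (input_ids == image_token_id).sum(1).max() < image_seq_length
    video_token_not_enough = (input_ids == video_token_id).sum(1).max() < video_seq_length
    inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or (
        video_token_not_enough and pixel_values_videos is not None
    )
    pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None)
    return bool(inputs_not_expanded or pixels_present)

# image-only prompt with a single un-expanded <image> token -> legacy path, even without videos
ids = torch.tensor([[1, 32000, 13]])
print(is_legacy(ids, pixel_values=torch.zeros(1, 3), pixel_values_videos=None))  # True
```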
View File

@@ -31,7 +31,6 @@ from torch import nn
from ... import PreTrainedModel
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...utils import (
@@ -767,6 +766,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
@@ -874,13 +874,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
inputs_expanded = (
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
legacy_processing = inputs_expanded or pixels_present
pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None)
legacy_processing = inputs_not_expanded or pixels_present
image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
@@ -899,75 +903,76 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
if legacy_processing:
logger.warning_once(
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
if legacy_processing:
logger.warning_once(
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
if input_ids.shape[1] != 1:
iterator = (
(image_features, feature_lens, self.config.image_token_index),
(video_features, video_feature_lens, self.config.video_token_index),
)
if input_ids.shape[1] != 1:
iterator = (
(image_features, feature_lens, self.config.image_token_index),
(video_features, video_feature_lens, self.config.video_token_index),
)
for features, lens, special_token in iterator:
if features is not None:
(
inputs_embeds,
attention_mask,
position_ids,
labels,
input_ids,
) = self._merge_input_ids_with_image_features(
features,
lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
image_token_index=special_token,
)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
# TODO: @raushan retain only the new behavior after v4.47
for features, lens, special_token in iterator:
if features is not None:
(
inputs_embeds,
attention_mask,
position_ids,
labels,
input_ids,
) = self._merge_input_ids_with_image_features(
features,
lens,
inputs_embeds,
input_ids,
attention_mask,
position_ids,
labels=labels,
image_token_index=special_token,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
if image_features is not None:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
if video_features is not None:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
# TODO: @raushan retain only the new behavior after v4.47
else:
if image_features is not None:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
@@ -978,6 +983,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)
@@ -1020,64 +1026,38 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
pixel_values_videos=None,
image_sizes=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if "num_logits_to_keep" != None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"image_sizes": image_sizes,
}
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if legacy_processing or cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes"] = image_sizes
return model_inputs
def _get_image_features(self, pixel_values, image_sizes):

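To illustrate the non-legacy branch above, a tiny standalone example of `masked_scatter`: embeddings at the positions of the special image token are overwritten, in order, with the projected image features (token id 32000 and the small shapes are assumed toy values).

```python
import torch

input_ids = torch.tensor([[5, 32000, 32000, 7]])                   # two image-token positions
inputs_embeds = torch.zeros(1, 4, 3)                               # (batch, seq, hidden)
image_features = torch.arange(6, dtype=torch.float).reshape(2, 3)  # one row per image-token position

special_image_mask = (input_ids == 32000).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1], inputs_embeds[0, 2])  # tensor([0., 1., 2.]) tensor([3., 4., 5.])
```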
View File

@@ -19,6 +19,7 @@ Processor class for LLaVa-NeXT-Video.
from typing import TYPE_CHECKING, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
@@ -160,35 +161,29 @@ class LlavaNextVideoProcessor(ProcessorMixin):
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
print(self.patch_size, self.vision_feature_select_strategy, image_inputs, videos_inputs.keys())
if self.patch_size is None or self.vision_feature_select_strategy is None:
prompt_strings = text
logger.warning_once(
"Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
# cannot infer image expansion length if no images/videos are found
elif not image_inputs and not videos_inputs:
prompt_strings = text
else:
# images expand taking into account num_of_patches in each image
if image_inputs:
image_sizes = image_inputs["image_sizes"]
image_sizes = iter(image_inputs["image_sizes"])
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
prompt_strings = []
for image_size, sample in zip(image_sizes, text):
# Replace the image token with the expanded image token sequence
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
for sample in text:
while self.image_token in sample:
image_size = next(image_sizes)
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
text = prompt_strings
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
# videos are easier, simply get frames and multiply
if videos_inputs:
@@ -197,23 +192,65 @@ class LlavaNextVideoProcessor(ProcessorMixin):
num_frames = one_video.shape[0] # frame dim is always after batch dim
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer
prompt_strings = []
for sample in text:
sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
prompt_strings.append(sample)
text = prompt_strings
text_inputs = self.tokenizer(
prompt_strings,
text,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
print(text_inputs.keys())
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
# Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
image_grid_pinpoints = self.image_processor.image_grid_pinpoints
height_best_resolution, width_best_resolution = select_best_resolution(
[orig_height, orig_width], image_grid_pinpoints
)
scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
patches_height = height // self.patch_size
patches_width = width // self.patch_size
unpadded_features, newline_features = self._get_unpadded_features(
orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
)
# The base patch covers the entire image (+1 for the CLS)
base_features = patches_height * patches_width + 1
num_image_tokens = unpadded_features + newline_features + base_features
return num_image_tokens
# Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_unpadded_features
def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
"""
Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
patches an image is divided into and get the number of features from that.
"""
current_height = patches_height * scale_height
current_width = patches_width * scale_width
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = (width * current_height) // height
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

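Worked numbers for the video branch above, under assumed values (336x336 frames, `patch_size` 14, 8 sampled frames): the division by 4 accounts for the 2x2 average pooling applied to each frame's features.

```python
patch_size = 14
height = width = 336
num_frames = 8

num_image_tokens = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576 per frame
num_video_tokens = num_image_tokens // 4 * num_frames              # 144 * 8 = 1152
print(num_video_tokens)  # 1152 copies of the video token
```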
View File

@@ -529,15 +529,19 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
inputs_expanded = (
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
pixels_present = (
input_ids.shape[-1] == 1 and pixel_values_images is not None and pixel_values_videos is not None
pixels_present = input_ids.shape[-1] == 1 and (
pixel_values_images is not None or pixel_values_videos is not None
)
legacy_processing = inputs_expanded or pixels_present
legacy_processing = inputs_not_expanded or pixels_present
if pixel_values_images is not None or pixel_values_videos is not None:
image_outputs, video_outputs, num_frames = self._get_vision_features(
@@ -577,6 +581,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
labels,
num_frames=frames,
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
@@ -606,6 +611,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
# TODO: @raushan retain only the new behavior after v4.47
else:
@@ -678,11 +686,16 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
num_logits_to_keep=None,
**kwargs,
):
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
legacy_processing = input_ids is not None and (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
and (input_ids == self.config.video_token_index).sum(1).max() < self.config.video_seq_length
)
if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
@@ -694,11 +707,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
**kwargs,
)
if legacy_processing:
model_inputs["pixel_values_images"] = pixel_values_images
model_inputs["pixel_values_videos"] = pixel_values_videos
elif cache_position[0] == 0:
if legacy_processing or cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values_images"] = pixel_values_images

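A toy contrast of the old and new `pixels_present` checks in the hunk above: at a cached decode step (a single input id) with an image-only prompt, the old `and` between the two modalities never triggered the legacy path. The helpers are illustrative stand-ins, not model code.

```python
def pixels_present_old(num_ids, pixel_values_images, pixel_values_videos):
    return num_ids == 1 and pixel_values_images is not None and pixel_values_videos is not None

def pixels_present_new(num_ids, pixel_values_images, pixel_values_videos):
    return num_ids == 1 and (pixel_values_images is not None or pixel_values_videos is not None)

imgs, vids = object(), None  # image-only prompt
print(pixels_present_old(1, imgs, vids))  # False -> legacy path was missed
print(pixels_present_new(1, imgs, vids))  # True  -> legacy path is taken
```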
View File

@@ -145,24 +145,28 @@ class VideoLlavaProcessor(ProcessorMixin):
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if encoded_images is not None and self.patch_size is None or self.vision_feature_select_strategy is None:
prompt_strings = text
prompt_strings = text
if encoded_images is not None and (self.patch_size is None or self.vision_feature_select_strategy is None):
logger.warning_once(
"Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
)
# Replace the image/video tokens with the expanded token sequence
elif encoded_images is not None:
# Replace the image token with the expanded image token sequence
if "pixel_values" in encoded_images:
height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values")[0]))
if "pixel_values_images" in encoded_images.keys():
height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0]))
num_frames = 1
else:
if "pixel_values_videos" in encoded_images.keys():
one_video = to_numpy_array(encoded_images.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
num_video_tokens = num_image_tokens * num_frames
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
num_video_tokens = num_image_tokens * num_frames
if self.vision_feature_select_strategy == "default":

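The added parentheses in the condition above change its meaning: without them, `a and b or c` groups as `(a and b) or c`, so the deprecation warning fired even when no images were passed at all. A short check of the difference, with assumed values:

```python
encoded_images = None  # no images passed
patch_size = 14
vision_feature_select_strategy = None

old = encoded_images is not None and patch_size is None or vision_feature_select_strategy is None
new = encoded_images is not None and (patch_size is None or vision_feature_select_strategy is None)
print(old, new)  # True False
```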
View File

@@ -471,6 +471,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
@@ -500,6 +501,9 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
# TODO: @raushan retain only the new behavior after v4.47
else:
@@ -579,9 +583,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
**kwargs,
)
if legacy_processing:
model_inputs["pixel_values"] = pixel_values
elif cache_position[0] == 0:
if legacy_processing or cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values

View File

@@ -302,7 +302,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Lastly, be respectful of the environment and other visitors, as the pier is a shared space where people can enjoy the view, relax, or engage in recreational activities." # fmt: skip
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
@@ -353,7 +353,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
EXPECTED_DECODED_TEXT = [
'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.',
'USER: \nWhat is this?\nASSISTANT: Cats'
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
@@ -393,7 +396,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@require_torch
@require_vision
def test_batched_generation(self):
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
@@ -415,9 +418,9 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
model = model.eval()
EXPECTED_OUTPUT = [
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog in one and a ll",
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
"\n \nUSER: What's the the difference of two images?\nASSISTANT: The difference between the two images is that one shows a dog standing on a grassy field, while",
"\nUSER: Describe the image.\nASSISTANT: The image features a brown and white dog sitting on a sidewalk. The dog is holding a small",
"\nUSER: Describe the image.\nASSISTANT: The image features a lone llama standing on a grassy hill. The llama is the",
]
generate_ids = model.generate(**inputs, max_new_tokens=20)
@@ -451,26 +454,23 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
def test_llava_merge_inputs_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(torch_device)
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
# Simulate some user inputs
pixel_values = torch.randn(
(2, 3, 336, 336),
(1, 3, 336, 336),
dtype=torch.float,
device=torch_device,
)
input_ids = torch.tensor(
[
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
],
dtype=torch.long,
device=torch_device,
)
attention_mask = torch.tensor(
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.long,
device=torch_device,
)
@@ -515,6 +515,31 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@slow
@require_bitsandbytes
def test_generation_siglip_backbone(self):
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="float16", device_map=torch_device)
processor = AutoProcessor.from_pretrained(model_id)
# check processing with expansion of inputs (w/o expansion should work with any backbone)
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(
text="<|im_start|>user\n<image>\nWhat are these?<|im_end|>\n<|im_start|>assistant",
images=raw_image,
return_tensors="pt",
).to(torch_device, torch.float16)
# Make sure that `generate` works
output = model.generate(**inputs, max_new_tokens=30)
EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. The cat"
self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_expansion_in_processing(self):

View File

@@ -363,11 +363,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model(**inputs)
expected_slice = torch.tensor(
[
[-4.7695, -4.5664, -0.2786],
[-10.6250, -10.8906, -2.5254],
[-6.7383, -7.2461, -0.6787],
],
[[-4.7695, -4.5664, -0.2788], [-10.6172, -10.8828, -2.5273], [-6.7383, -7.2422, -0.6694]],
dtype=torch.float32,
device=torch_device,
)
@@ -471,16 +467,16 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model(**inputs)
expected_slice = torch.tensor(
[[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]],
[[-0.1287, -0.1294, -0.1284], [-0.2744, -0.2698, -0.2671], [-0.1071, -0.1091, -0.1056]],
dtype=torch.float32,
device=torch_device,
)
assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3)
assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device))
assert torch.allclose(output.loss, torch.tensor(7.0206, device=torch_device), atol=1e-3)
# verify generation
output = model.generate(**inputs, max_new_tokens=50)
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background' # fmt: skip
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given' # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
@@ -534,38 +530,66 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# model is in eval mode by default so we should get pad on the left side
# we can check the first hidden-states (aka inputs embeds)
# the first element was lo-res image and we expect the first 1414 tokens to be all pads
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item())
# otherwise padding is on the right side, so it's last 1414 tokens
self.processor.padding_side = "right"
inputs_batched = self.processor(
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
).to(torch_device)
model.train()
# the first element was lo-res image and we expect the first 732 tokens to be all pads
with torch.no_grad():
output_train = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :732, ...] == 0).all().item())
with self.assertLogs("transformers", level="WARNING") as logs:
model.padding_side = "left"
model.train()
model(**inputs_batched, output_hidden_states=True)
with torch.no_grad():
model(**inputs_batched, output_hidden_states=True)
self.assertIn(
"Padding side is set to 'left' but the model is in training mode. For training", logs.output[0]
)
self.assertIn("Padding side is set to 'left' but the model is in training mode. For training", logs)
with self.assertLogs("transformers", level="WARNING") as logs:
model.padding_side = "right"
model.eval()
model(**inputs_batched, output_hidden_states=True)
with torch.no_grad():
model(**inputs_batched, output_hidden_states=True)
self.assertIn(
"Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0]
)
self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs)
@slow
@require_bitsandbytes
def test_expansion_in_processing_multiimage(self):
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image><image>\nDescribe the similarity between the two images:\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
deer_image = Image.open(
requests.get(
"https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e",
stream=True,
).raw
)
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3969)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
self.assertTrue(inputs.input_ids.shape[-1] == 23)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes

View File

@@ -18,6 +18,7 @@ import gc
import unittest
import numpy as np
import requests
from huggingface_hub import hf_hub_download
from transformers import (
@@ -363,29 +364,6 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
)
inputs = self.processor(self.prompt_video, videos=self.video, return_tensors="pt")
expected_input_ids = [
1,
3148,
1001,
29901,
29871,
32000,
13,
11008,
338,
445,
4863,
2090,
1460,
29973,
319,
1799,
9047,
13566,
29901,
]
self.assertListEqual(expected_input_ids, inputs.input_ids[0].tolist())
# verify single forward pass
inputs = inputs.to(torch_device)
with torch.no_grad():
@@ -393,7 +371,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# verify generation
output = model.generate(**inputs, do_sample=False, max_new_tokens=40)
EXPECTED_DECODED_TEXT = 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the book. The child appears to be reading a book, but instead of a calm and focused reading experience' # fmt: skip
EXPECTED_DECODED_TEXT = 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems' # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
@ -416,7 +394,10 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the', 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the'] # fmt: skip
EXPECTED_DECODED_TEXT = [
'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a',
'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a'
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
@ -447,7 +428,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# verify generation
output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a benchmark test for a machine learning model. It shows the performance of various models on a task, with the x-axis representing the number of parameters (measured in millions) and the y' # fmt: skip
EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"' # fmt: skip
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@ -493,41 +474,25 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# model is in eval mode by default so we should get pad on the left side
# we can check the first hidden-states (aka inputs embeds)
# the first element was lo-res image and we expect the first 1482 tokens to be all pads
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
# otherwise padding is on the right side, so it's last 1482 tokens
self.processor.padding_side = "right"
inputs_batched = self.processor(
[self.prompt_video, self.prompt_image],
images=[self.image],
videos=[self.video],
return_tensors="pt",
padding=True,
).to(torch_device)
model.train()
with torch.no_grad():
output_train = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())
output_eval = model(**inputs_batched, output_hidden_states=True)
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
with self.assertLogs("transformers", level="WARNING") as logs:
model.padding_side = "left"
model.train()
model(**inputs_batched, output_hidden_states=True)
with torch.no_grad():
model(**inputs_batched, output_hidden_states=True)
self.assertIn(
"Padding side is set to 'left' but the model is in training mode. For training", logs.output[0]
)
self.assertIn("Padding side is set to 'left' but the model is in training mode. For training", logs)
with self.assertLogs("transformers", level="WARNING") as logs:
model.padding_side = "right"
model.eval()
model(**inputs_batched, output_hidden_states=True)
with torch.no_grad():
model(**inputs_batched, output_hidden_states=True)
self.assertIn(
"Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0]
)
self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs)
@slow
@require_bitsandbytes
@ -556,3 +521,73 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_expansion_in_processing_images(self):
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
)
processor = AutoProcessor.from_pretrained(model_id)
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2652)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
@slow
@require_bitsandbytes
def test_expansion_in_processing_multiimage(self):
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
)
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image><image>\nDescribe the similarity between the two images:\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
deer_image = Image.open(
requests.get(
"https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e",
stream=True,
).raw
)
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3968)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
self.assertTrue(inputs.input_ids.shape[-1] == 22)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())

View File

@ -383,18 +383,19 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# Let' s make sure we test the preprocessing to replace what is used
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
prompt = "USER: <video>Why is this video funny? ASSISTANT:"
prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
video_file = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
)
video_file = np.load(video_file)
inputs = self.processor(prompt, videos=video_file, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[1, 3148, 1001, 29901, 29871, 32001, 3750, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901]]) # fmt: skip
EXPECTED_INPUT_IDS = torch.tensor([[1, 3148, 1001, 29901, 29871, 32001, 13, 11008, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "USER: Why is this video funny? ASSISTANT: The video is funny because the baby is playing with a Wii remote while sitting on a bed" # fmt: skip
EXPECTED_DECODED_TEXT = "USER: \nWhy is this video funny? ASSISTANT: The video is funny because it shows a baby sitting on a bed and reading a book, which" # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
@ -404,12 +405,11 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@slow
@require_bitsandbytes
def test_small_model_integration_test_mixed_inputs(self):
# Let' s make sure we test the preprocessing to replace what is used
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
prompts = [
"USER: <image>What are the cats in the image doing? ASSISTANT:",
"USER: <video>Why is this video funny? ASSISTANT:",
"USER: <image>\nWhat are the cats in the image doing? ASSISTANT:",
"USER: <video>\nWhy is this video funny? ASSISTANT:",
]
video_file = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
@ -422,8 +422,8 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
EXPECTED_DECODED_TEXT = [
'USER: What are the cats in the image doing? ASSISTANT: The cats in the image are lying down on a red couch, possibly sleeping or rest',
'USER: Why is this video funny? ASSISTANT: The video is funny because the baby is playing with a Wii remote while sitting on a bed'
'USER: \nWhat are the cats in the image doing? ASSISTANT: The cats in the image are sleeping or resting on a couch.',
'USER: \nWhy is this video funny? ASSISTANT: The video is funny because it shows a baby sitting on a bed and reading a book. The'
] # fmt: skip
self.assertEqual(
@ -434,12 +434,10 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama(self):
# Let' s make sure we test the preprocessing to replace what is used
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
prompt = "USER: <video>Describe the video in details. ASSISTANT:"
prompt = "USER: <video>\nDescribe the video in details. ASSISTANT:"
video_file = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
)
@ -447,11 +445,11 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: Describe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. " \
"The child appears to be enjoying the book, as they are fully engaged in the reading process. The bed is located in a bedroom, and there is a chair nearby. " \
"The child is wearing a light blue shirt and pink pants, and they have glasses on. The room is well-lit, and there is a clock on the wall. The child seems " \
"to be in a comfortable and relaxed environment, which is conducive to reading and learning. Overall, the video captures a heartwarming moment of a child " \
"engaging in a simple yet essential activity, which is reading." # fmt: skip
EXPECTED_DECODED_TEXT = "USER: \nDescribe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. " \
"The child appears to be enjoying the book, as they are fully engaged in the activity. The bed is located in a bedroom, and there is a chair nearby. The " \
"child is wearing a blue shirt and glasses, which suggests that they might have a visual impairment. The room is well-lit, and there is a clock on the wall, " \
"indicating the time. The child's focus on the book indicates that they are interested in the content and are actively participating in the reading process. " \
"Overall, the video captures a heartwarming moment of a child engaging in a simple yet essential activity, which is reading." # fmt: skip
self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
@ -461,15 +459,13 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched(self):
# Let' s make sure we test the preprocessing to replace what is used
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
processor.tokenizer.padding_side = "left"
prompts = [
"USER: <video>What is the baby doing? ASSISTANT:",
"USER: <video>Who is sitting next to the woman? ASSISTANT:",
"USER: <video>\nWhat is the baby doing? ASSISTANT:",
"USER: <video>\nWho is sitting next to the woman? ASSISTANT:",
]
video_1 = np.load(
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset")
@ -483,48 +479,12 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = [
'USER: What is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.Ъ',
'USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman.Ъ'
'USER: \nWhat is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.',
'USER: \nWho is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman.'
] # fmt: skip
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched_regression(self):
# Let' s make sure we test the preprocessing to replace what is used
# Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before)
model = VideoLlavaForConditionalGeneration.from_pretrained(
"LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True, attn_implementation="eager"
)
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", pad_token="<pad>")
processor.tokenizer.padding_side = "left"
prompts = [
"USER: <video>What is the baby doing? ASSISTANT:",
"USER: <video>Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman. USER: <video>What about this video? ASSITANT:",
]
video_1 = np.load(
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset")
)
video_2 = np.load(
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset")
)
inputs = processor(prompts, videos=[video_1, video_2, video_1], return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
# fmt: off
EXPECTED_DECODED_TEXT = [
'USER: What is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.Ъ',
'USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman. USER: What about this video? ASSITANT: The video shows a baby sitting on a bed, reading a book. The baby is wearing glass'
]
# fmt: on
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_video_llava_index_error_bug(self):
@ -552,32 +512,23 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
@require_torch_gpu
def test_video_llava_merge_inputs_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
model = VideoLlavaForConditionalGeneration.from_pretrained(
"LanguageBind/Video-LLaVA-7B-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(torch_device)
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
# Simulate some user inputs
pixel_values_videos = torch.randn(
(2, 8, 3, 224, 224),
(1, 8, 3, 224, 224),
dtype=torch.float,
device=torch_device,
)
# fmt: off
input_ids = torch.tensor(
[
[
32001, 32001, 1, 15043, 7084, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 29871, 13, 7900
],
[
1, 15043, 7084, 29901, 29871, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 29871, 13, 7900
],
],
[[32002, 32002, 1, 15043, 7084, 32001, 29871, 13, 7900]],
dtype=torch.long,
device=torch_device,
)
# fmt: on
attention_mask = torch.tensor(
[[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.long,
device=torch_device,
)
@ -591,6 +542,36 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
).loss
loss.backward()
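The reproducer above feeds a left-padded batch (two leading pad ids, an attention mask starting with zeros) together with random pixel values, and only checks that the forward and backward passes run. A generic sketch, not necessarily this model's exact code path, of how position ids are commonly derived so that the left padding is skipped:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1  # running count of real tokens
position_ids.masked_fill_(attention_mask == 0, 1)    # padded slots get a dummy position
# the first real token sits at position 0; the two pads are masked out anyway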
@slow
@require_bitsandbytes
def test_expansion_in_processing_images(self):
model_id = "LanguageBind/Video-LLaVA-7B-hf"
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = VideoLlavaProcessor.from_pretrained(model_id)
prompt = "USER: <image>\nDescribe the image in details. ASSISTANT:"
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(prompt, images=image, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 274)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(prompt, images=image, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
# check that both inputs are handled correctly and generate the same output
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
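The token counts in this test are consistent with Video-LLaVA's default 224x224 crop and patch_size=14; back-of-the-envelope arithmetic, not the processor's actual counting code:

# (224 // 14) ** 2 = 256 patch features per image (the extra CLS feature is
# dropped under the "default" select strategy), so the single <image>
# placeholder in the 19-token prompt becomes 256 image tokens:
assert 19 - 1 + (224 // 14) ** 2 == 274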
@slow
@require_bitsandbytes
def test_expansion_in_processing(self):
@ -598,7 +579,7 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = VideoLlavaProcessor.from_pretrained(model_id)
prompt = "USER: <video>Describe the video in details. ASSISTANT:"
prompt = "USER: <video>\nDescribe the video in details. ASSISTANT:"
video_file = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
)
@ -608,13 +589,13 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
inputs_expanded = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2073)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2074)
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
inputs = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 18)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
# generate exactly 20 tokens
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
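For the video prompt the numbers line up with 8 sampled frames at the same resolution, apparently with one extra feature kept per frame; again a rough consistency check rather than the library's counting code:

# (224 // 14) ** 2 + 1 = 257 tokens per frame, 8 frames -> 2056 video tokens,
# and the 19-token legacy prompt expands to 19 - 1 + 2056 == 2074 ids.
assert 19 - 1 + 8 * ((224 // 14) ** 2 + 1) == 2074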

View File

@ -271,26 +271,23 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
def test_vipllava_merge_inputs_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
model_id = "llava-hf/vip-llava-7b-hf"
model = VipLlavaForConditionalGeneration.from_pretrained(
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(torch_device)
model = VipLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
# Simulate some user inputs
pixel_values = torch.randn(
(2, 3, 336, 336),
(1, 3, 336, 336),
dtype=torch.float,
device=torch_device,
)
input_ids = torch.tensor(
[
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
],
dtype=torch.long,
device=torch_device,
)
attention_mask = torch.tensor(
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.long,
device=torch_device,
)