mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
VLM: fixes after refactor (#32907)
* leave only half of the changes * fix tests * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava * fix tests, first try * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava * fix, second try * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava * fix * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava
This commit is contained in:
parent
f24f084329
commit
7d2d6ce9cb
@ -476,6 +476,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
@ -506,6 +507,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
|
||||
-target_length:
|
||||
]
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
@ -585,9 +589,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
elif cache_position[0] == 0:
|
||||
if legacy_processing or cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
@ -136,6 +136,7 @@ class LlavaProcessor(ProcessorMixin):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
# try to expand inputs in processing if we have the necessary parts
|
||||
prompt_strings = text
|
||||
if image_inputs.get("pixel_values") is not None:
|
||||
if self.patch_size is not None and self.vision_feature_select_strategy is not None:
|
||||
# Replace the image token with the expanded image token sequence
|
||||
@ -150,7 +151,6 @@ class LlavaProcessor(ProcessorMixin):
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
prompt_strings.append(sample)
|
||||
else:
|
||||
prompt_strings = text
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
|
@ -848,6 +848,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
position_ids,
|
||||
labels=labels,
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
@ -877,6 +878,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
|
||||
-target_length:
|
||||
]
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
@ -956,12 +960,9 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
elif cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
if legacy_processing or cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
|
||||
|
@ -140,30 +140,29 @@ class LlavaNextProcessor(ProcessorMixin):
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
if self.patch_size is None or self.vision_feature_select_strategy is None:
|
||||
prompt_strings = text
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
# cannot infer image expansion length if no images are found
|
||||
elif not image_inputs:
|
||||
prompt_strings = text
|
||||
else:
|
||||
image_sizes = image_inputs["image_sizes"]
|
||||
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
|
||||
prompt_strings = []
|
||||
for image_size, sample in zip(image_sizes, text):
|
||||
# Replace the image token with the expanded image token sequence
|
||||
orig_height, orig_width = image_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
prompt_strings.append(sample)
|
||||
prompt_strings = text
|
||||
if image_inputs:
|
||||
if self.patch_size is None or self.vision_feature_select_strategy is None:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
else:
|
||||
image_sizes = iter(image_inputs["image_sizes"])
|
||||
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
|
||||
prompt_strings = []
|
||||
for sample in text:
|
||||
while self.image_token in sample:
|
||||
image_size = next(image_sizes)
|
||||
orig_height, orig_width = image_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
|
||||
prompt_strings.append(sample)
|
||||
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt_strings,
|
||||
|
@ -29,7 +29,6 @@ from transformers.models.llava_next.modeling_llava_next import (
|
||||
image_size_to_num_patches,
|
||||
)
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@ -389,13 +388,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
|
||||
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
|
||||
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
|
||||
inputs_expanded = (
|
||||
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
|
||||
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
|
||||
1
|
||||
).max() < self.config.image_seq_length
|
||||
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
|
||||
1
|
||||
).max() < self.config.video_seq_length
|
||||
inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or (
|
||||
video_token_not_enough and pixel_values_videos is not None
|
||||
)
|
||||
pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
|
||||
legacy_processing = inputs_expanded or pixels_present
|
||||
pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None)
|
||||
legacy_processing = inputs_not_expanded or pixels_present
|
||||
|
||||
image_features = feature_lens = None
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
@ -414,75 +417,76 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
iterator = (
|
||||
(image_features, feature_lens, self.config.image_token_index),
|
||||
(video_features, video_feature_lens, self.config.video_token_index),
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
iterator = (
|
||||
(image_features, feature_lens, self.config.image_token_index),
|
||||
(video_features, video_feature_lens, self.config.video_token_index),
|
||||
)
|
||||
for features, lens, special_token in zip(iterator):
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
features,
|
||||
lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=special_token,
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
for features, lens, special_token in iterator:
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
features,
|
||||
lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=special_token,
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
else:
|
||||
if image_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
|
||||
|
||||
if video_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
if image_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
if video_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
@ -493,6 +497,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
@ -534,58 +539,34 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
|
||||
pixel_values_videos=None,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
if input_ids is not None:
|
||||
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
|
||||
1
|
||||
).max() < self.config.image_seq_length
|
||||
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
|
||||
1
|
||||
).max() < self.config.video_seq_length
|
||||
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
|
||||
video_token_not_enough and pixel_values_videos is not None
|
||||
)
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"image_sizes": image_sizes,
|
||||
}
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
if legacy_processing or cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
|
||||
return model_inputs
|
||||
|
@ -31,7 +31,6 @@ from torch import nn
|
||||
|
||||
from ... import PreTrainedModel
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...modeling_outputs import ModelOutput
|
||||
from ...utils import (
|
||||
@ -767,6 +766,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
|
||||
r"""
|
||||
@ -874,13 +874,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
|
||||
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
|
||||
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
|
||||
inputs_expanded = (
|
||||
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
|
||||
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
|
||||
1
|
||||
).max() < self.config.image_seq_length
|
||||
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
|
||||
1
|
||||
).max() < self.config.video_seq_length
|
||||
inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or (
|
||||
video_token_not_enough and pixel_values_videos is not None
|
||||
)
|
||||
pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None
|
||||
legacy_processing = inputs_expanded or pixels_present
|
||||
pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None)
|
||||
legacy_processing = inputs_not_expanded or pixels_present
|
||||
|
||||
image_features = feature_lens = None
|
||||
if pixel_values is not None and pixel_values.size(0) > 0:
|
||||
@ -899,75 +903,76 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
if legacy_processing:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
iterator = (
|
||||
(image_features, feature_lens, self.config.image_token_index),
|
||||
(video_features, video_feature_lens, self.config.video_token_index),
|
||||
)
|
||||
if input_ids.shape[1] != 1:
|
||||
iterator = (
|
||||
(image_features, feature_lens, self.config.image_token_index),
|
||||
(video_features, video_feature_lens, self.config.video_token_index),
|
||||
)
|
||||
for features, lens, special_token in iterator:
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
features,
|
||||
lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=special_token,
|
||||
)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
for features, lens, special_token in iterator:
|
||||
if features is not None:
|
||||
(
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels,
|
||||
input_ids,
|
||||
) = self._merge_input_ids_with_image_features(
|
||||
features,
|
||||
lens,
|
||||
inputs_embeds,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
labels=labels,
|
||||
image_token_index=special_token,
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
else:
|
||||
if image_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
||||
# Get the target length
|
||||
target_length = input_ids.shape[1]
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
|
||||
|
||||
if video_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
if image_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
if video_features is not None:
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
)
|
||||
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
@ -978,6 +983,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
|
||||
@ -1020,64 +1026,38 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel):
|
||||
pixel_values_videos=None,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, Cache):
|
||||
cache_length = past_key_values.get_seq_length()
|
||||
past_length = past_key_values.seen_tokens
|
||||
else:
|
||||
cache_length = past_length = past_key_values[0][0].shape[2]
|
||||
if input_ids is not None:
|
||||
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
|
||||
1
|
||||
).max() < self.config.image_seq_length
|
||||
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
|
||||
1
|
||||
).max() < self.config.video_seq_length
|
||||
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
|
||||
video_token_not_enough and pixel_values_videos is not None
|
||||
)
|
||||
|
||||
# Keep only the unprocessed tokens:
|
||||
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
||||
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
|
||||
# input)
|
||||
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
||||
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
||||
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
||||
# input_ids based on the past_length.
|
||||
elif past_length < input_ids.shape[1]:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
|
||||
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
||||
elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids:
|
||||
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
||||
|
||||
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
||||
# older attention values, as their corresponding values are not part of the input.
|
||||
if cache_length < past_length and attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
if "num_logits_to_keep" != None:
|
||||
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"image_sizes": image_sizes,
|
||||
}
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
if legacy_processing or cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
model_inputs["image_sizes"] = image_sizes
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _get_image_features(self, pixel_values, image_sizes):
|
||||
|
@ -19,6 +19,7 @@ Processor class for LLaVa-NeXT-Video.
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
@ -160,35 +161,29 @@ class LlavaNextVideoProcessor(ProcessorMixin):
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
print(self.patch_size, self.vision_feature_select_strategy, image_inputs, videos_inputs.keys())
|
||||
|
||||
if self.patch_size is None or self.vision_feature_select_strategy is None:
|
||||
prompt_strings = text
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||
)
|
||||
# cannot infer image expansion length if no images/videos are found
|
||||
elif not image_inputs and not videos_inputs:
|
||||
prompt_strings = text
|
||||
else:
|
||||
# images expand taking into account num_of_patches in each image
|
||||
if image_inputs:
|
||||
image_sizes = image_inputs["image_sizes"]
|
||||
image_sizes = iter(image_inputs["image_sizes"])
|
||||
height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
|
||||
prompt_strings = []
|
||||
for image_size, sample in zip(image_sizes, text):
|
||||
# Replace the image token with the expanded image token sequence
|
||||
orig_height, orig_width = image_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
for sample in text:
|
||||
while self.image_token in sample:
|
||||
image_size = next(image_sizes)
|
||||
orig_height, orig_width = image_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
|
||||
prompt_strings.append(sample)
|
||||
text = prompt_strings
|
||||
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
||||
|
||||
# videos are easier, simply get frames and multiply
|
||||
if videos_inputs:
|
||||
@ -197,23 +192,65 @@ class LlavaNextVideoProcessor(ProcessorMixin):
|
||||
num_frames = one_video.shape[0] # frame dim is always after batch dim
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
|
||||
num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||
|
||||
prompt_strings = []
|
||||
for sample in text:
|
||||
sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
|
||||
prompt_strings.append(sample)
|
||||
text = prompt_strings
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt_strings,
|
||||
text,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
print(text_inputs.keys())
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
||||
|
||||
# Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features
|
||||
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
|
||||
image_grid_pinpoints = self.image_processor.image_grid_pinpoints
|
||||
|
||||
height_best_resolution, width_best_resolution = select_best_resolution(
|
||||
[orig_height, orig_width], image_grid_pinpoints
|
||||
)
|
||||
scale_height, scale_width = height_best_resolution // height, width_best_resolution // width
|
||||
|
||||
patches_height = height // self.patch_size
|
||||
patches_width = width // self.patch_size
|
||||
unpadded_features, newline_features = self._get_unpadded_features(
|
||||
orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
|
||||
)
|
||||
# The base patch covers the entire image (+1 for the CLS)
|
||||
base_features = patches_height * patches_width + 1
|
||||
num_image_tokens = unpadded_features + newline_features + base_features
|
||||
return num_image_tokens
|
||||
|
||||
# Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_unpadded_features
|
||||
def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width):
|
||||
"""
|
||||
Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA
|
||||
because it divided each image into patches depending on its resolution. Therefore we need to calculate how many
|
||||
patches an image is divided into and get the number of features from that.
|
||||
"""
|
||||
current_height = patches_height * scale_height
|
||||
current_width = patches_width * scale_width
|
||||
|
||||
original_aspect_ratio = width / height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height -= padding * 2
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width -= padding * 2
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -529,15 +529,19 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
|
||||
# if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing
|
||||
# not very reliable, but we don't expect one to actually pass 500+ images for one prompt
|
||||
img_token_count = (input_ids == self.config.image_token_index).sum(1).max()
|
||||
video_token_count = (input_ids == self.config.video_token_index).sum(1).max()
|
||||
inputs_expanded = (
|
||||
img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length
|
||||
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
|
||||
1
|
||||
).max() < self.config.image_seq_length
|
||||
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
|
||||
1
|
||||
).max() < self.config.video_seq_length
|
||||
inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or (
|
||||
video_token_not_enough and pixel_values_videos is not None
|
||||
)
|
||||
pixels_present = (
|
||||
input_ids.shape[-1] == 1 and pixel_values_images is not None and pixel_values_videos is not None
|
||||
pixels_present = input_ids.shape[-1] == 1 and (
|
||||
pixel_values_images is not None or pixel_values_videos is not None
|
||||
)
|
||||
legacy_processing = inputs_expanded or pixels_present
|
||||
legacy_processing = inputs_not_expanded or pixels_present
|
||||
|
||||
if pixel_values_images is not None or pixel_values_videos is not None:
|
||||
image_outputs, video_outputs, num_frames = self._get_vision_features(
|
||||
@ -577,6 +581,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
labels,
|
||||
num_frames=frames,
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
@ -606,6 +611,9 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
|
||||
-target_length:
|
||||
]
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
@ -678,11 +686,16 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
|
||||
legacy_processing = input_ids is not None and (
|
||||
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
|
||||
and (input_ids == self.config.video_token_index).sum(1).max() < self.config.video_seq_length
|
||||
)
|
||||
if input_ids is not None:
|
||||
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
|
||||
1
|
||||
).max() < self.config.image_seq_length
|
||||
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
|
||||
1
|
||||
).max() < self.config.video_seq_length
|
||||
legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
|
||||
video_token_not_enough and pixel_values_videos is not None
|
||||
)
|
||||
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
@ -694,11 +707,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values_images"] = pixel_values_images
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
|
||||
elif cache_position[0] == 0:
|
||||
if legacy_processing or cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values_images"] = pixel_values_images
|
||||
|
@ -145,24 +145,28 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
if encoded_images is not None and self.patch_size is None or self.vision_feature_select_strategy is None:
|
||||
prompt_strings = text
|
||||
prompt_strings = text
|
||||
if encoded_images is not None and (self.patch_size is None or self.vision_feature_select_strategy is None):
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in Video-LLaVa should be done in processing. "
|
||||
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
|
||||
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
|
||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
|
||||
)
|
||||
# Replace the image/video tokens with the expanded token sequence
|
||||
elif encoded_images is not None:
|
||||
# Replace the image token with the expanded image token sequence
|
||||
if "pixel_values" in encoded_images:
|
||||
height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values")[0]))
|
||||
if "pixel_values_images" in encoded_images.keys():
|
||||
height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0]))
|
||||
num_frames = 1
|
||||
else:
|
||||
|
||||
if "pixel_values_videos" in encoded_images.keys():
|
||||
one_video = to_numpy_array(encoded_images.get("pixel_values_videos")[0])
|
||||
height, width = get_image_size(one_video[0])
|
||||
num_frames = one_video.shape[0] # frame dim is always after batch dim
|
||||
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
|
||||
num_video_tokens = num_image_tokens * num_frames
|
||||
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
|
||||
num_video_tokens = num_image_tokens * num_frames
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
|
@ -471,6 +471,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
)
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
|
||||
else:
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
@ -500,6 +501,9 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
|
||||
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
|
||||
-target_length:
|
||||
]
|
||||
|
||||
# TODO: @raushan retain only the new behavior after v4.47
|
||||
else:
|
||||
@ -579,9 +583,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if legacy_processing:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
elif cache_position[0] == 0:
|
||||
if legacy_processing or cache_position[0] == 0:
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
@ -302,7 +302,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
|
||||
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Lastly, be respectful of the environment and other visitors, as the pier is a shared space where people can enjoy the view, relax, or engage in recreational activities." # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
processor.decode(output[0], skip_special_tokens=True),
|
||||
@ -353,7 +353,10 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.',
|
||||
'USER: \nWhat is this?\nASSISTANT: Cats'
|
||||
] # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
@ -393,7 +396,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_batched_generation(self):
|
||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
|
||||
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
@ -415,9 +418,9 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
model = model.eval()
|
||||
|
||||
EXPECTED_OUTPUT = [
|
||||
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog in one and a ll",
|
||||
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
|
||||
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
|
||||
"\n \nUSER: What's the the difference of two images?\nASSISTANT: The difference between the two images is that one shows a dog standing on a grassy field, while",
|
||||
"\nUSER: Describe the image.\nASSISTANT: The image features a brown and white dog sitting on a sidewalk. The dog is holding a small",
|
||||
"\nUSER: Describe the image.\nASSISTANT: The image features a lone llama standing on a grassy hill. The llama is the",
|
||||
]
|
||||
|
||||
generate_ids = model.generate(**inputs, max_new_tokens=20)
|
||||
@ -451,26 +454,23 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
def test_llava_merge_inputs_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||
model_id = "llava-hf/llava-1.5-7b-hf"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
|
||||
# Simulate some user inputs
|
||||
pixel_values = torch.randn(
|
||||
(2, 3, 336, 336),
|
||||
(1, 3, 336, 336),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
|
||||
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
attention_mask = torch.tensor(
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
@ -515,6 +515,31 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_generation_siglip_backbone(self):
|
||||
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="float16", device_map=torch_device)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
# check processing with expansion of inputs (w/o expansion should work with any backbone)
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
|
||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
inputs = processor(
|
||||
text="<|im_start|>user\n<image>\nWhat are these?<|im_end|>\n<|im_start|>assistant",
|
||||
images=raw_image,
|
||||
return_tensors="pt",
|
||||
).to(torch_device, torch.float16)
|
||||
|
||||
# Make sure that `generate` works
|
||||
output = model.generate(**inputs, max_new_tokens=30)
|
||||
|
||||
EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. The cat"
|
||||
self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing(self):
|
||||
|
@ -363,11 +363,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
output = model(**inputs)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[-4.7695, -4.5664, -0.2786],
|
||||
[-10.6250, -10.8906, -2.5254],
|
||||
[-6.7383, -7.2461, -0.6787],
|
||||
],
|
||||
[[-4.7695, -4.5664, -0.2788], [-10.6172, -10.8828, -2.5273], [-6.7383, -7.2422, -0.6694]],
|
||||
dtype=torch.float32,
|
||||
device=torch_device,
|
||||
)
|
||||
@ -471,16 +467,16 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
output = model(**inputs)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]],
|
||||
[[-0.1287, -0.1294, -0.1284], [-0.2744, -0.2698, -0.2671], [-0.1071, -0.1091, -0.1056]],
|
||||
dtype=torch.float32,
|
||||
device=torch_device,
|
||||
)
|
||||
assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3)
|
||||
assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device))
|
||||
assert torch.allclose(output.loss, torch.tensor(7.0206, device=torch_device), atol=1e-3)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background' # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given' # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
@ -534,38 +530,66 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# model is in eval mode by default so we should get pad on the left side
|
||||
# we can check the first hidden-states (aka inputs embeds)
|
||||
# the first element was lo-res image and we expect the first 1414 tokens to be all pads
|
||||
output_eval = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item())
|
||||
|
||||
# otherwise padding is on the right side, so it's last 1414 tokens
|
||||
self.processor.padding_side = "right"
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
|
||||
).to(torch_device)
|
||||
|
||||
model.train()
|
||||
# the first element was lo-res image and we expect the first 732 tokens to be all pads
|
||||
with torch.no_grad():
|
||||
output_train = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())
|
||||
output_eval = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_eval.hidden_states[0][0, :732, ...] == 0).all().item())
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as logs:
|
||||
model.padding_side = "left"
|
||||
model.train()
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
with torch.no_grad():
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
|
||||
self.assertIn(
|
||||
"Padding side is set to 'left' but the model is in training mode. For training", logs.output[0]
|
||||
)
|
||||
self.assertIn("Padding side is set to 'left' but the model is in training mode. For training", logs)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as logs:
|
||||
model.padding_side = "right"
|
||||
model.eval()
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
with torch.no_grad():
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
|
||||
self.assertIn(
|
||||
"Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0]
|
||||
)
|
||||
self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing_multiimage(self):
|
||||
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
|
||||
model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "USER: <image><image>\nDescribe the similarity between the two images:\nASSISTANT:"
|
||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
deer_image = Image.open(
|
||||
requests.get(
|
||||
"https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e",
|
||||
stream=True,
|
||||
).raw
|
||||
)
|
||||
|
||||
# check processing with expansion of inputs
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3969)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 23)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
|
@ -18,6 +18,7 @@ import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import (
|
||||
@ -363,29 +364,6 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
inputs = self.processor(self.prompt_video, videos=self.video, return_tensors="pt")
|
||||
expected_input_ids = [
|
||||
1,
|
||||
3148,
|
||||
1001,
|
||||
29901,
|
||||
29871,
|
||||
32000,
|
||||
13,
|
||||
11008,
|
||||
338,
|
||||
445,
|
||||
4863,
|
||||
2090,
|
||||
1460,
|
||||
29973,
|
||||
319,
|
||||
1799,
|
||||
9047,
|
||||
13566,
|
||||
29901,
|
||||
]
|
||||
self.assertListEqual(expected_input_ids, inputs.input_ids[0].tolist())
|
||||
|
||||
# verify single forward pass
|
||||
inputs = inputs.to(torch_device)
|
||||
with torch.no_grad():
|
||||
@ -393,7 +371,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=40)
|
||||
EXPECTED_DECODED_TEXT = 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the book. The child appears to be reading a book, but instead of a calm and focused reading experience' # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems' # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
@ -416,7 +394,10 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = ['USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the', 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the'] # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a',
|
||||
'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a'
|
||||
] # fmt: skip
|
||||
self.assertEqual(
|
||||
self.processor.batch_decode(output, skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
@ -447,7 +428,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
|
||||
EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a benchmark test for a machine learning model. It shows the performance of various models on a task, with the x-axis representing the number of parameters (measured in millions) and the y' # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"' # fmt: skip
|
||||
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@ -493,41 +474,25 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
# model is in eval mode by default so we should get pad on the left side
|
||||
# we can check the first hidden-states (aka inputs embeds)
|
||||
# the first element was lo-res image and we expect the first 1482 tokens to be all pads
|
||||
output_eval = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
|
||||
|
||||
# otherwise padding is on the right side, so it's last 1482 tokens
|
||||
self.processor.padding_side = "right"
|
||||
inputs_batched = self.processor(
|
||||
[self.prompt_video, self.prompt_image],
|
||||
images=[self.image],
|
||||
videos=[self.video],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
).to(torch_device)
|
||||
|
||||
model.train()
|
||||
with torch.no_grad():
|
||||
output_train = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())
|
||||
output_eval = model(**inputs_batched, output_hidden_states=True)
|
||||
self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as logs:
|
||||
model.padding_side = "left"
|
||||
model.train()
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
with torch.no_grad():
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
|
||||
self.assertIn(
|
||||
"Padding side is set to 'left' but the model is in training mode. For training", logs.output[0]
|
||||
)
|
||||
self.assertIn("Padding side is set to 'left' but the model is in training mode. For training", logs)
|
||||
|
||||
with self.assertLogs("transformers", level="WARNING") as logs:
|
||||
model.padding_side = "right"
|
||||
model.eval()
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
with torch.no_grad():
|
||||
model(**inputs_batched, output_hidden_states=True)
|
||||
|
||||
self.assertIn(
|
||||
"Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0]
|
||||
)
|
||||
self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
@ -556,3 +521,73 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing_images(self):
|
||||
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
|
||||
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
# check processing with expansion of inputs
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2652)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 19)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing_multiimage(self):
|
||||
model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
|
||||
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "USER: <image><image>\nDescribe the similarity between the two images:\nASSISTANT:"
|
||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||
deer_image = Image.open(
|
||||
requests.get(
|
||||
"https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e",
|
||||
stream=True,
|
||||
).raw
|
||||
)
|
||||
|
||||
# check processing with expansion of inputs
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3968)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
|
||||
torch_device, torch.float16
|
||||
)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 22)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
@ -383,18 +383,19 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
|
||||
|
||||
prompt = "USER: <video>Why is this video funny? ASSISTANT:"
|
||||
prompt = "USER: <video>\nWhy is this video funny? ASSISTANT:"
|
||||
video_file = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
||||
)
|
||||
video_file = np.load(video_file)
|
||||
inputs = self.processor(prompt, videos=video_file, return_tensors="pt")
|
||||
|
||||
EXPECTED_INPUT_IDS = torch.tensor([[1, 3148, 1001, 29901, 29871, 32001, 3750, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901]]) # fmt: skip
|
||||
EXPECTED_INPUT_IDS = torch.tensor([[1, 3148, 1001, 29901, 29871, 32001, 13, 11008, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901]]) # fmt: skip
|
||||
|
||||
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
||||
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||
EXPECTED_DECODED_TEXT = "USER: Why is this video funny? ASSISTANT: The video is funny because the baby is playing with a Wii remote while sitting on a bed" # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = "USER: \nWhy is this video funny? ASSISTANT: The video is funny because it shows a baby sitting on a bed and reading a book, which" # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
@ -404,12 +405,11 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_mixed_inputs(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
|
||||
|
||||
prompts = [
|
||||
"USER: <image>What are the cats in the image doing? ASSISTANT:",
|
||||
"USER: <video>Why is this video funny? ASSISTANT:",
|
||||
"USER: <image>\nWhat are the cats in the image doing? ASSISTANT:",
|
||||
"USER: <video>\nWhy is this video funny? ASSISTANT:",
|
||||
]
|
||||
video_file = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
||||
@ -422,8 +422,8 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
'USER: What are the cats in the image doing? ASSISTANT: The cats in the image are lying down on a red couch, possibly sleeping or rest',
|
||||
'USER: Why is this video funny? ASSISTANT: The video is funny because the baby is playing with a Wii remote while sitting on a bed'
|
||||
'USER: \nWhat are the cats in the image doing? ASSISTANT: The cats in the image are sleeping or resting on a couch.',
|
||||
'USER: \nWhy is this video funny? ASSISTANT: The video is funny because it shows a baby sitting on a bed and reading a book. The'
|
||||
] # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
@ -434,12 +434,10 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_llama(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
|
||||
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
|
||||
|
||||
prompt = "USER: <video>Describe the video in details. ASSISTANT:"
|
||||
prompt = "USER: <video>\nDescribe the video in details. ASSISTANT:"
|
||||
video_file = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
||||
)
|
||||
@ -447,11 +445,11 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
|
||||
EXPECTED_DECODED_TEXT = "USER: Describe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. " \
|
||||
"The child appears to be enjoying the book, as they are fully engaged in the reading process. The bed is located in a bedroom, and there is a chair nearby. " \
|
||||
"The child is wearing a light blue shirt and pink pants, and they have glasses on. The room is well-lit, and there is a clock on the wall. The child seems " \
|
||||
"to be in a comfortable and relaxed environment, which is conducive to reading and learning. Overall, the video captures a heartwarming moment of a child " \
|
||||
"engaging in a simple yet essential activity, which is reading." # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = "USER: \nDescribe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. " \
|
||||
"The child appears to be enjoying the book, as they are fully engaged in the activity. The bed is located in a bedroom, and there is a chair nearby. The " \
|
||||
"child is wearing a blue shirt and glasses, which suggests that they might have a visual impairment. The room is well-lit, and there is a clock on the wall, " \
|
||||
"indicating the time. The child's focus on the book indicates that they are interested in the content and are actively participating in the reading process. " \
|
||||
"Overall, the video captures a heartwarming moment of a child engaging in a simple yet essential activity, which is reading." # fmt: skip
|
||||
|
||||
self.assertEqual(
|
||||
processor.decode(output[0], skip_special_tokens=True),
|
||||
@ -461,15 +459,13 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_llama_batched(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
|
||||
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")
|
||||
processor.tokenizer.padding_side = "left"
|
||||
|
||||
prompts = [
|
||||
"USER: <video>What is the baby doing? ASSISTANT:",
|
||||
"USER: <video>Who is sitting next to the woman? ASSISTANT:",
|
||||
"USER: <video>\nWhat is the baby doing? ASSISTANT:",
|
||||
"USER: <video>\nWho is sitting next to the woman? ASSISTANT:",
|
||||
]
|
||||
video_1 = np.load(
|
||||
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset")
|
||||
@ -483,48 +479,12 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
'USER: What is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.Ъ',
|
||||
'USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman.Ъ'
|
||||
'USER: \nWhat is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.',
|
||||
'USER: \nWho is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman.'
|
||||
] # fmt: skip
|
||||
|
||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_small_model_integration_test_llama_batched_regression(self):
|
||||
# Let' s make sure we test the preprocessing to replace what is used
|
||||
|
||||
# Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before)
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained(
|
||||
"LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True, attn_implementation="eager"
|
||||
)
|
||||
processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", pad_token="<pad>")
|
||||
processor.tokenizer.padding_side = "left"
|
||||
|
||||
prompts = [
|
||||
"USER: <video>What is the baby doing? ASSISTANT:",
|
||||
"USER: <video>Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman. USER: <video>What about this video? ASSITANT:",
|
||||
]
|
||||
video_1 = np.load(
|
||||
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset")
|
||||
)
|
||||
video_2 = np.load(
|
||||
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset")
|
||||
)
|
||||
|
||||
inputs = processor(prompts, videos=[video_1, video_2, video_1], return_tensors="pt", padding=True)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_DECODED_TEXT = [
|
||||
'USER: What is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.Ъ',
|
||||
'USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman. USER: What about this video? ASSITANT: The video shows a baby sitting on a bed, reading a book. The baby is wearing glass'
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_video_llava_index_error_bug(self):
|
||||
@ -552,32 +512,23 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
def test_video_llava_merge_inputs_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained(
|
||||
"LanguageBind/Video-LLaVA-7B-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True)
|
||||
|
||||
# Simulate some user inputs
|
||||
pixel_values_videos = torch.randn(
|
||||
(2, 8, 3, 224, 224),
|
||||
(1, 8, 3, 224, 224),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: off
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
32001, 32001, 1, 15043, 7084, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 29871, 13, 7900
|
||||
],
|
||||
[
|
||||
1, 15043, 7084, 29901, 29871, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 29871, 13, 7900
|
||||
],
|
||||
],
|
||||
[[32002, 32002, 1, 15043, 7084, 32001, 29871, 13, 7900]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
attention_mask = torch.tensor(
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
@ -591,6 +542,36 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
).loss
|
||||
loss.backward()
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing_images(self):
|
||||
model_id = "LanguageBind/Video-LLaVA-7B-hf"
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
processor = VideoLlavaProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "USER: <image>\nDescribe the image in details. ASSISTANT:"
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
# check processing with expansion of inputs
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(prompt, images=image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 274)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(prompt, images=image, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 19)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20)
|
||||
|
||||
# check that both inputs are handled correctly and generate the same output
|
||||
self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist())
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_expansion_in_processing(self):
|
||||
@ -598,7 +579,7 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
processor = VideoLlavaProcessor.from_pretrained(model_id)
|
||||
|
||||
prompt = "USER: <video>Describe the video in details. ASSISTANT:"
|
||||
prompt = "USER: <video>\nDescribe the video in details. ASSISTANT:"
|
||||
video_file = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
||||
)
|
||||
@ -608,13 +589,13 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
processor.vision_feature_select_strategy = "default"
|
||||
processor.patch_size = 14
|
||||
inputs_expanded = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2073)
|
||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2074)
|
||||
|
||||
# check processing without expansion of inputs (legacy behavior)
|
||||
processor.vision_feature_select_strategy = None
|
||||
processor.patch_size = None
|
||||
inputs = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 18)
|
||||
self.assertTrue(inputs.input_ids.shape[-1] == 19)
|
||||
|
||||
# generate exactly 20 tokens
|
||||
output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20)
|
||||
|
@ -271,26 +271,23 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
def test_vipllava_merge_inputs_error_bug(self):
|
||||
# This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
|
||||
model_id = "llava-hf/vip-llava-7b-hf"
|
||||
model = VipLlavaForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
model = VipLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
|
||||
|
||||
# Simulate some user inputs
|
||||
pixel_values = torch.randn(
|
||||
(2, 3, 336, 336),
|
||||
(1, 3, 336, 336),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
|
||||
[1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
attention_mask = torch.tensor(
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1, 1, 1, 1]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user