Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
parent a769ed45e1
commit fe76b60370
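Taken together, the hunks below appear to make two related changes across the LLaVA-family models (Llava, LlavaNext, LlavaNextVideo, VideoLlava, VipLlava): in `forward`, `image_features` is initialized to `None` and the `legacy_processing` fall-back is moved out of the `if pixel_values is not None:` block, with the non-legacy path now guarded by `elif image_features is not None:`; in `prepare_inputs_for_generation`, the `legacy_processing` re-computation is dropped and vision inputs are only forwarded on the prefill step, i.e. when `cache_position[0] == 0`.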
@@ -472,6 +472,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
@@ -479,69 +480,67 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )

-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+            n_image_features = image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
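Note: the non-legacy path above relies on `torch.Tensor.masked_scatter` to splice the projected image features into the token embeddings at the positions that hold the image placeholder token. The snippet below is a minimal, self-contained sketch of that mechanism with made-up shapes and a hypothetical placeholder id; it is not the model code itself.

import torch

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id, not taken from the diff
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 5, 6, 7]])  # (batch=1, seq=6)
inputs_embeds = torch.zeros(1, 6, 4)                                      # (batch, seq, hidden=4)
image_features = torch.ones(1, 2, 4)                                      # one row per placeholder token

# Mark every placeholder position and broadcast the mask over the hidden dim,
# mirroring the .unsqueeze(-1).expand_as(inputs_embeds) chain in the diff.
special_image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)

# masked_scatter consumes image_features element by element and fills exactly the
# masked slots; the counts must match, which is what the n_image_tokens check guards.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features.to(inputs_embeds.dtype))

assert inputs_embeds[0, 1].eq(1).all() and inputs_embeds[0, 3].eq(0).all()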
@@ -602,12 +601,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -618,7 +611,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values"] = pixel_values
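Note: the `prepare_inputs_for_generation` change above keys the forwarding of `pixel_values` purely on the cache position: only the prefill step (where `cache_position[0] == 0`) still contains image placeholder tokens, so later decode steps drop the vision inputs instead of re-encoding them. A small illustrative sketch of that gating, using hypothetical names rather than the transformers API:

import torch

def attach_pixel_values(model_inputs: dict, cache_position: torch.Tensor, pixel_values) -> dict:
    # cache_position starts at 0 only for the prefill forward pass; during cached
    # decoding it points past the existing key/value cache, so the vision inputs
    # are simply not attached.
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs

prefill = attach_pixel_values({}, torch.arange(12), pixel_values="<pixel tensor>")
decode = attach_pixel_values({}, torch.tensor([12]), pixel_values="<pixel tensor>")
assert "pixel_values" in prefill and "pixel_values" not in decode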
@@ -846,6 +846,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
         if pixel_values is not None and pixel_values.size(0) > 0:
             image_features = self.get_image_features(
                 pixel_values,
@@ -861,74 +862,73 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                 vision_feature_select_strategy=vision_feature_select_strategy,
                 image_newline=self.image_newline,
             )
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                if input_ids.shape[1] != 1:
-                    inputs_embeds = inputs_embeds.to(image_features.dtype)
-                    inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
-                        image_features,
-                        feature_lens,
-                        inputs_embeds,
-                        input_ids,
-                        attention_mask,
-                        position_ids,
-                        labels=labels,
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
-                n_image_features = image_features.shape[0]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            if input_ids.shape[1] != 1:
+                inputs_embeds = inputs_embeds.to(image_features.dtype)
+                inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
+                    image_features,
+                    feature_lens,
+                    inputs_embeds,
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    labels=labels,
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
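Note: the legacy decode-stage branch kept above (in both the LLaVA and LLaVA-NeXT hunks) rebuilds the attention mask over the cached sequence, zeroing out cache slots that were never attended to. A rough standalone sketch of that masking logic with invented sizes follows; `first_layer_past_key_value` stands in for `past_key_values[0][0][:, :, :, 0]` and is not the real cache object.

import torch

batch, heads, past_length, target_length = 2, 4, 5, 1
# (batch, num_heads, past_length): the first element of head_dim for layer 0's key cache
first_layer_past_key_value = torch.ones(batch, heads, past_length)
first_layer_past_key_value[0, :, 3] = 0  # pretend slot 3 of sample 0 was padding
attention_mask = torch.ones(batch, target_length, dtype=torch.long)

# A cache slot whose key is zero for every head was never attended to.
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

extended_attention_mask = torch.ones((batch, past_length), dtype=attention_mask.dtype)
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
extended_attention_mask[batch_index[valid_indices], non_attended_tokens[valid_indices]] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1])[-target_length:]
print(attention_mask)  # row 0 carries a 0 at the padded slot; row 1 stays all ones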
@@ -990,11 +990,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1007,7 +1002,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):

         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["image_sizes"] = image_sizes

@@ -1110,17 +1110,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- extra custom processing

-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
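Note: the blocks removed here (and in the modular LlavaNextVideo and VideoLlava hunks below) re-derived `legacy_processing` inside `prepare_inputs_for_generation` by checking whether the prompt holds fewer placeholder tokens than the configured per-image or per-video sequence length; the heuristic itself survives in `forward`. The sketch below illustrates that check with made-up token ids and config values, not the actual model configuration.

import torch

# Assumed toy values; the real models read these from their config.
IMAGE_TOKEN_ID, VIDEO_TOKEN_ID = 32000, 32001
image_seq_length, video_seq_length = 576, 288

def needs_legacy_expansion(input_ids: torch.Tensor) -> tuple[bool, bool]:
    # If the prompt contains fewer placeholder tokens than one expanded image/video
    # would occupy, the processor did not expand them and the model must do it.
    img_token_not_enough = (input_ids == IMAGE_TOKEN_ID).sum(1).max() < image_seq_length
    video_token_not_enough = (input_ids == VIDEO_TOKEN_ID).sum(1).max() < video_seq_length
    return bool(img_token_not_enough), bool(video_token_not_enough)

ids = torch.tensor([[1, IMAGE_TOKEN_ID, 5, 6]])  # a single un-expanded image placeholder
print(needs_legacy_expansion(ids))  # (True, True)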
@@ -1133,7 +1122,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):

         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes
@@ -623,17 +623,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
     ):
         # Overwritten -- extra custom processing

-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -646,7 +635,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):

         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes
@@ -720,17 +720,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -741,7 +730,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values_images"] = pixel_values_images
@@ -466,72 +466,71 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
             )

-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # in the case one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # in the case one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+            n_image_features = image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -590,12 +589,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -606,7 +599,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )

-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values"] = pixel_values