LLaVA: latency issues (#34460)

* fix llavas

* code style

* green ci
Author: Raushan Turganbay, 2024-10-29 07:54:51 +01:00 (committed by GitHub)
parent a769ed45e1
commit fe76b60370
6 changed files with 187 additions and 239 deletions
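
The substance of the change, as the hunks below show: each `prepare_inputs_for_generation` override stops recomputing a `legacy_processing` flag on every decoding step, and pixel values (or video pixel values) are forwarded only during prefill, detected by `cache_position[0] == 0`; the `forward` methods initialize `image_features = None` so the merge branches run only when features were actually computed. A minimal, illustrative sketch of the prefill/decode gating (not the library code; names are made up):

# Illustrative only: why `cache_position[0] == 0` identifies the prefill step.
# During prefill the cache is empty, so positions start at 0 and the vision
# tower must run; during cached decoding only the newly generated token is fed
# in, positions start past 0, and the image features already live in the cache.
import torch

def should_forward_pixel_values(cache_position: torch.Tensor) -> bool:
    return cache_position[0].item() == 0

prefill_positions = torch.arange(10)      # first pass over the full prompt
decode_positions = torch.tensor([10])     # one new token per later step
print(should_forward_pixel_values(prefill_positions))  # True  -> pass pixel_values
print(should_forward_pixel_values(decode_positions))   # False -> skip the vision inputs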


@@ -472,6 +472,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
@@ -479,69 +480,67 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
 
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+            n_image_features = image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -602,12 +601,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -618,7 +611,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )
 
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values"] = pixel_values


@@ -846,6 +846,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
+        image_features = None
         if pixel_values is not None and pixel_values.size(0) > 0:
             image_features = self.get_image_features(
                 pixel_values,
@@ -861,74 +862,73 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
                 vision_feature_select_strategy=vision_feature_select_strategy,
                 image_newline=self.image_newline,
             )
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                if input_ids.shape[1] != 1:
-                    inputs_embeds = inputs_embeds.to(image_features.dtype)
-                    inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
-                        image_features,
-                        feature_lens,
-                        inputs_embeds,
-                        input_ids,
-                        attention_mask,
-                        position_ids,
-                        labels=labels,
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    # Get the target length
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # if one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
-                n_image_features = image_features.shape[0]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            if input_ids.shape[1] != 1:
+                inputs_embeds = inputs_embeds.to(image_features.dtype)
+                inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
+                    image_features,
+                    feature_lens,
+                    inputs_embeds,
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    labels=labels,
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -990,11 +990,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1007,7 +1002,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
         )
 
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["image_sizes"] = image_sizes


@@ -1110,17 +1110,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- extra custom processing
 
-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -1133,7 +1122,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
 
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes


@@ -623,17 +623,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
     ):
         # Overwritten -- extra custom processing
 
-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -646,7 +635,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
 
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
             model_inputs["pixel_values_videos"] = pixel_values_videos
             model_inputs["image_sizes"] = image_sizes


@@ -720,17 +720,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        if input_ids is not None:
-            img_token_not_enough = (input_ids == self.config.image_token_index).sum(
-                1
-            ).max() < self.config.image_seq_length
-            video_token_not_enough = (input_ids == self.config.video_token_index).sum(
-                1
-            ).max() < self.config.video_seq_length
-            legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
-                video_token_not_enough and pixel_values_videos is not None
-            )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -741,7 +730,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )
 
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
             # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
             model_inputs["pixel_values_images"] = pixel_values_images


@@ -466,72 +466,71 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
 
+        image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
             )
 
-            if legacy_processing:
-                logger.warning_once(
-                    "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
-                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
-                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-                )
-                # prefill stage vs decoding stage (legacy behavior copied)
-                if input_ids.shape[1] != 1:
-                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                        image_features, inputs_embeds, input_ids, attention_mask, labels
-                    )
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-                else:
-                    # Retrieve the first layer to inspect the logits and mask out the hidden states
-                    # that are set to 0
-                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                    # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                    target_length = input_ids.shape[1]
-                    past_length = first_layer_past_key_value.shape[-1]
-
-                    extended_attention_mask = torch.ones(
-                        (attention_mask.shape[0], past_length),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-
-                    # Filter out only the tokens that can be un-attended, this can happen
-                    # in the case one uses Llava + Fused modules where the cache on the
-                    # first iteration is already big enough, or if one passes custom cache
-                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                    new_batch_index = batch_index[valid_indices]
-                    new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                    # Zero-out the places where we don't need to attend
-                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                        -target_length:
-                    ]
-
-            # TODO: @raushan retain only the new behavior after v4.47
-            else:
-                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-                n_image_features = image_features.shape[1]
-                if n_image_tokens != n_image_features:
-                    raise ValueError(
-                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                    )
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in VipLLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's image processing config. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # in the case one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+            n_image_features = image_features.shape[1]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -590,12 +589,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
 
-        # Trigger the new behavior if we have more than image embeddings seq length tokens for images
-        legacy_processing = (
-            input_ids is not None
-            and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-        )
-
         model_inputs = self.language_model.prepare_inputs_for_generation(
             input_ids,
             past_key_values=past_key_values,
@@ -606,7 +599,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
             **kwargs,
         )
 
-        if legacy_processing or cache_position[0] == 0:
+        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = pixel_values