LLaVA: latency issues (#34460)

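* fix llavas: drop the `legacy_processing` token-count check from `prepare_inputs_for_generation` (pixel inputs are now forwarded only when `cache_position[0] == 0`), and in `forward` initialise `image_features = None` and merge image features only when they were actually computed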

* code style

* green ci
This commit is contained in:
Raushan Turganbay 2024-10-29 07:54:51 +01:00 committed by GitHub
parent a769ed45e1
commit fe76b60370
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 187 additions and 239 deletions

View File

@ -472,6 +472,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
@ -522,12 +523,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
# TODO: @raushan retain only the new behavior after v4.47
else:
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
@ -602,12 +601,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
legacy_processing = (
input_ids is not None
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@ -618,7 +611,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
**kwargs,
)
if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values

View File

@ -846,6 +846,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
image_features = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
@ -861,6 +862,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
@ -909,12 +911,10 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
# TODO: @raushan retain only the new behavior after v4.47
else:
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
@ -990,11 +990,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
legacy_processing = (
input_ids is not None
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@ -1007,7 +1002,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["image_sizes"] = image_sizes

View File

@ -1110,17 +1110,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
):
# Overwritten -- extra custom processing
if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@ -1133,7 +1122,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes"] = image_sizes

View File

@ -623,17 +623,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
):
# Overwritten -- extra custom processing
if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@ -646,7 +635,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
model_inputs["image_sizes"] = image_sizes

View File

@ -720,17 +720,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
if input_ids is not None:
img_token_not_enough = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
video_token_not_enough = (input_ids == self.config.video_token_index).sum(
1
).max() < self.config.video_seq_length
legacy_processing = (img_token_not_enough and pixel_values_images is not None) or (
video_token_not_enough and pixel_values_videos is not None
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@ -741,7 +730,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
**kwargs,
)
if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values_images"] = pixel_values_images

View File

@ -466,6 +466,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)
image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
@ -512,12 +513,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
# TODO: @raushan retain only the new behavior after v4.47
else:
elif image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
@ -590,12 +589,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
legacy_processing = (
input_ids is not None
and (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
)
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@ -606,7 +599,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
**kwargs,
)
if legacy_processing or cache_position[0] == 0:
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values