[feat] LlavaNext: add feature size check to avoid CUDA Runtime Error (#33608)

* [feat] add feature size check to avoid CUDA Runtime Error

* [minor] add error handling to all LLaVA models

* [minor] avoid nested if/else

* [minor] add error message to Qwen2-VL and Chameleon

* [fix] token dimension used in the check

* [minor] add feature dim check for videos too

* [fix] dimension check

* [fix] test reference values

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
laurentd-lunit 2024-10-15 23:19:18 +09:00 committed by GitHub
parent d00f1ca860
commit 0f49deacbf
11 changed files with 88 additions and 7 deletions
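
Every modeling change below follows the same pattern: before scattering vision features into the sequence, compare the number of placeholder tokens in input_ids with the number of feature rows and raise a readable ValueError on mismatch, instead of letting masked_scatter fail inside a CUDA kernel. A minimal standalone sketch of that guard (the token id and shapes are made up for illustration, not taken from any real checkpoint):

    import torch

    IMAGE_TOKEN_ID = 32000                 # hypothetical placeholder id
    input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
    image_features = torch.randn(3, 64)    # 3 feature rows, but only 2 placeholders

    n_image_tokens = (input_ids == IMAGE_TOKEN_ID).sum().item()
    n_image_features = image_features.shape[0]
    if n_image_tokens != n_image_features:
        # Fails early, on the host, with actionable numbers in the message.
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        )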


@@ -1287,6 +1287,12 @@ class ChameleonModel(ChameleonPreTrainedModel):
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
+            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
+            n_image_features = image_tokens.shape[0]
+            if n_image_tokens_in_text != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
+                )
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

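Chameleon differs from the LLaVA-style models below: it scatters discrete image token ids into input_ids rather than feature vectors into inputs_embeds, so the check counts rows of image_tokens directly. A toy version of that flow (all ids invented for illustration):

    import torch

    PLACEHOLDER_ID = 500                              # stand-in for vocabulary_mapping.image_token_id
    input_ids = torch.tensor([[1, PLACEHOLDER_ID, PLACEHOLDER_ID, PLACEHOLDER_ID, 2]])
    image_tokens = torch.tensor([9001, 9002, 9003])   # one discrete id per placeholder

    special_image_mask = input_ids == PLACEHOLDER_ID
    input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
    print(input_ids)  # tensor([[   1, 9001, 9002, 9003,    2]])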

@@ -518,6 +518,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                n_image_features = image_features.shape[1]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 special_image_mask = (
                     (input_ids == self.config.image_token_index)
                     .unsqueeze(-1)

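Note the dimensions used here: plain LLaVA keeps a batch axis, with image_features shaped roughly (num_images, tokens_per_image, hidden), so the feature count comes from shape[1] and the token count from a single batch row via sum(dim=-1)[0]. The flattened variants further down (LLaVA-NeXT, OneVision, Qwen2-VL) instead compare shape[0] against a global sum(). A toy check under that batched assumption:

    import torch

    # Assumed shapes for illustration: 2 images, 576 patch tokens each
    # (576 = (336 / 14) ** 2 for a standard LLaVA vision tower).
    image_features = torch.randn(2, 576, 4096)
    input_ids = torch.zeros(2, 600, dtype=torch.long)
    input_ids[:, :576] = 32000                       # toy image-token id

    n_image_tokens = (input_ids == 32000).sum(dim=-1)[0].item()  # per-row count: 576
    n_image_features = image_features.shape[1]                   # tokens per image: 576
    assert n_image_tokens == n_image_features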

@@ -895,6 +895,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 special_image_mask = (
                     (input_ids == self.config.image_token_index)
                     .unsqueeze(-1)


@@ -967,6 +967,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_features is not None:
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens != n_image_features:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.image_token_index)
                         .unsqueeze(-1)
@@ -976,6 +982,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_features is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens != n_video_features:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.video_token_index)
                         .unsqueeze(-1)

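The guard sits directly in front of a masked_scatter that expands the placeholder mask to the hidden dimension and writes one feature row per masked position; when the counts disagree, that scatter is what used to blow up on CUDA. A simplified version of the scatter with toy shapes and a toy placeholder id:

    import torch

    hidden = 8
    inputs_embeds = torch.zeros(1, 4, hidden)        # (batch, seq, hidden)
    input_ids = torch.tensor([[1, 99, 99, 2]])       # 99 = toy placeholder id
    image_features = torch.randn(2, hidden)          # exactly one row per placeholder

    special_image_mask = (
        (input_ids == 99)
        .unsqueeze(-1)
        .expand_as(inputs_embeds)
    )
    # Fills the 2 masked positions with the 2 feature rows, in order.
    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)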

@@ -482,6 +482,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_features is not None:
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens != n_image_features:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.image_token_index)
                         .unsqueeze(-1)
@@ -491,6 +497,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_features is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens != n_video_features:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.video_token_index)
                         .unsqueeze(-1)


@@ -619,7 +619,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
                 image_newline=self.image_newline,
                 vision_aspect_ratio=vision_aspect_ratio,
             )
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
             special_image_mask = (
                 (input_ids == self.config.image_token_index)
                 .unsqueeze(-1)
@@ -647,7 +652,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
             image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
             video_features = torch.cat((video_features, image_newline), dim=1)
             video_features = video_features.flatten(0, 1)
+            n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+            n_video_features = video_features.shape[0]
+            if n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
             special_video_mask = (
                 (input_ids == self.config.video_token_index)
                 .unsqueeze(-1)


@@ -1710,6 +1710,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.get_dtype())
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                n_image_features = image_embeds.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 image_mask = (
                     (input_ids == self.config.image_token_id)
                     .unsqueeze(-1)
@@ -1722,6 +1728,12 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             if pixel_values_videos is not None:
                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                 video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+                n_video_features = video_embeds.shape[0]
+                if n_video_tokens != n_video_features:
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
                 video_mask = (
                     (input_ids == self.config.video_token_id)
                     .unsqueeze(-1)

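For Qwen2-VL the visual tower returns embeddings already flattened to (n_tokens, hidden) across all images, which is why shape[0] is compared against a global token count. Rough accounting for a single image, assuming the model's default 2x2 spatial patch merging (the grid values here are hypothetical):

    grid_t, grid_h, grid_w = 1, 32, 32     # hypothetical grid_thw from the processor
    spatial_merge_size = 2                 # assumption: default 2x2 patch merging
    n_image_features = (grid_t * grid_h * grid_w) // spatial_merge_size**2
    print(n_image_features)                # 256 image tokens expected in input_ids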

@@ -618,6 +618,12 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_outputs is not None:
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                    n_image_features = image_features.shape[1]
+                    if n_image_tokens != n_image_features:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.image_token_index)
                         .unsqueeze(-1)
@@ -626,8 +632,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
                     )
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_outputs is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
+                    n_video_features = video_features.shape[1]
+                    if n_video_tokens != n_video_features:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.video_token_index)
                         .unsqueeze(-1)


@@ -511,6 +511,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                n_image_features = image_features.shape[1]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 special_image_mask = (
                     (input_ids == self.config.image_token_index)
                     .unsqueeze(-1)


@@ -118,8 +118,8 @@ class LlavaVisionText2TextModelTester:
         self.batch_size = 3
         self.num_channels = 3
         self.image_size = 336
-        self.encoder_seq_length = 231
-        self.num_image_tokens = 224
+        self.encoder_seq_length = 232
+        self.num_image_tokens = 225
         self.seq_length = seq_length + self.num_image_tokens

     def get_config(self):


@@ -111,8 +111,8 @@ class VipLlavaVisionText2TextModelTester:
         self.batch_size = 3
         self.num_channels = 3
         self.image_size = 336
-        self.encoder_seq_length = 231
-        self.num_image_tokens = 224
+        self.encoder_seq_length = 232
+        self.num_image_tokens = 225
         self.seq_length = seq_length + self.num_image_tokens

     def get_config(self):