Fix device mismatch errors (#33851)

Move the special-token boolean masks onto inputs_embeds.device before calling masked_scatter, so that LLaVA, LLaVA-NeXT, LLaVA-NeXT-Video, Qwen2-VL, Video-LLaVA and VipLLaVA no longer raise device mismatch errors when their weights are sharded across devices. A minimal sketch of the failure mode follows the changed-files summary below.
Raushan Turganbay 2024-10-01 15:55:57 +02:00 committed by GitHub
parent ac28a23b3d
commit b1c914e463
7 changed files with 48 additions and 11 deletions
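
Every hunk applies the same pattern: the boolean mask is built from input_ids, so it inherits that tensor's device, while inputs_embeds and the projected image/video features can sit on another device once the checkpoint is sharded (for example with device_map="auto"). masked_scatter then fails with a device mismatch unless the mask is moved explicitly. Below is a minimal, self-contained sketch of the failure mode and the fix; the token id and tensor shapes are made up for illustration and are not taken from any of the models.

import torch

# Hypothetical id and tiny shapes, chosen only to make the example runnable.
device = "cuda" if torch.cuda.is_available() else "cpu"
image_token_index = 32000

input_ids = torch.tensor([[1, 32000, 32000, 2]])        # created on CPU
inputs_embeds = torch.randn(1, 4, 8, device=device)     # may live on a GPU shard
image_features = torch.randn(2, 8, device=device)       # one 8-dim row per image token

special_image_mask = (
    (input_ids == image_token_index)   # bool mask on input_ids' device (CPU)
    .unsqueeze(-1)
    .expand_as(inputs_embeds)          # expand only reads sizes, not the device
    .to(inputs_embeds.device)          # the fix: align the mask with the embeds
)
# Without the .to(...) above, masked_scatter raises a device mismatch error
# whenever input_ids and inputs_embeds end up on different devices.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds.shape)  # torch.Size([1, 4, 8])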

@@ -521,7 +521,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

@@ -898,7 +898,10 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

@@ -980,13 +980,19 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
             else:
                 if image_features is not None:
                     special_image_mask = (
-                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        (input_ids == self.config.image_token_index)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
                     )
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_features is not None:
                     special_image_mask = (
-                        (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        (input_ids == self.config.video_token_index)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
                     )
                     video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

@@ -485,13 +485,19 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
             else:
                 if image_features is not None:
                     special_image_mask = (
-                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        (input_ids == self.config.image_token_index)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
                     )
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_features is not None:
                     special_image_mask = (
-                        (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        (input_ids == self.config.video_token_index)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
                     )
                     video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

@@ -1690,14 +1690,24 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.get_dtype())
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-                image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_mask = (
+                    (input_ids == self.config.image_token_id)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
                 image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
             if pixel_values_videos is not None:
                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                 video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-                video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_mask = (
+                    (input_ids == self.config.video_token_id)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
                 video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

@@ -621,14 +621,20 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
             else:
                 if image_outputs is not None:
                     special_image_mask = (
-                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        (input_ids == self.config.image_token_index)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
                     )
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_outputs is not None:
                     special_image_mask = (
-                        (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        (input_ids == self.config.video_token_index)
+                        .unsqueeze(-1)
+                        .expand_as(inputs_embeds)
+                        .to(inputs_embeds.device)
                     )
                     video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

@@ -514,7 +514,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
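
The error the commit title refers to only surfaces when a checkpoint is split across devices. A hedged usage sketch that exercises one of the patched models under sharding; the checkpoint name, image URL and prompt are illustrative choices, not taken from this commit:

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Assumes a machine with more than one GPU so device_map="auto" actually
# shards the weights; on a single device the mismatch never appears.
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)

# Before this fix, generation could stop inside masked_scatter with a device
# mismatch; with the mask moved to inputs_embeds.device it runs through.
output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))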