Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Commit b1c914e463 (parent ac28a23b3d)
@@ -521,7 +521,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
         # TODO: @raushan retain only the new behavior after v4.47
         else:
             special_image_mask = (
-                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
             )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
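Every hunk in this commit makes the same change: the one-line placeholder-mask expression is replaced by a chained version that ends in .to(inputs_embeds.device), so the mask built from input_ids always lands on the same device as the embeddings it indexes before masked_scatter writes the projected image features into the sequence. A minimal, self-contained sketch of that pattern follows; the token id, shapes, and tensors are made up for illustration and are not taken from the commit.

import torch

# Hypothetical sizes: batch=1, seq_len=4, hidden=8, and a made-up placeholder id.
image_token_index = 32000
input_ids = torch.tensor([[101, 32000, 32000, 102]])   # may live on a different device than the embeddings
inputs_embeds = torch.zeros(1, 4, 8)                    # token embeddings, (batch, seq_len, hidden)
image_features = torch.randn(2, 8)                      # one projected feature row per image placeholder token

special_image_mask = (
    (input_ids == image_token_index)
    .unsqueeze(-1)                  # (batch, seq_len, 1)
    .expand_as(inputs_embeds)       # (batch, seq_len, hidden): mark every hidden unit of placeholder positions
    .to(inputs_embeds.device)       # the step added by this commit: keep the mask on the embeddings' device
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, 1])          # placeholder rows now hold the scattered image features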
@@ -898,7 +898,10 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
         # TODO: @raushan retain only the new behavior after v4.47
         else:
             special_image_mask = (
-                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
             )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -980,13 +980,19 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
         else:
             if image_features is not None:
                 special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
             if video_features is not None:
                 special_image_mask = (
-                    (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.video_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
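The video-capable models (LlavaNextVideo above, and Qwen2VL and VideoLlava below) apply the same recipe twice against the same inputs_embeds, once per modality. A small sketch of that two-pass scatter, again with assumed token ids and toy shapes rather than values from the commit:

import torch

# Assumed placeholder ids and toy shapes: batch=1, seq_len=5, hidden=4.
image_token_index, video_token_index = 32000, 32001
input_ids = torch.tensor([[101, 32000, 32001, 32001, 102]])
inputs_embeds = torch.zeros(1, 5, 4)
image_features = torch.randn(1, 4)      # one row per image placeholder token
video_features = torch.randn(2, 4)      # one row per video placeholder token

for token_index, features in ((image_token_index, image_features), (video_token_index, video_features)):
    mask = (
        (input_ids == token_index)
        .unsqueeze(-1)
        .expand_as(inputs_embeds)
        .to(inputs_embeds.device)       # same device alignment as in the hunks above
    )
    features = features.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs_embeds = inputs_embeds.masked_scatter(mask, features)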
@@ -485,13 +485,19 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
         else:
             if image_features is not None:
                 special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
             if video_features is not None:
                 special_image_mask = (
-                    (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.video_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
@@ -1690,14 +1690,24 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.get_dtype())
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-                image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_mask = (
+                    (input_ids == self.config.image_token_id)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
                 image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

             if pixel_values_videos is not None:
                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                 video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-                video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_mask = (
+                    (input_ids == self.config.video_token_id)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
                 video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

@@ -621,14 +621,20 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
         else:
             if image_outputs is not None:
                 special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

             if video_outputs is not None:
                 special_image_mask = (
-                    (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    (input_ids == self.config.video_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
                 )
                 video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
                 inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
@@ -514,7 +514,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
         # TODO: @raushan retain only the new behavior after v4.47
         else:
             special_image_mask = (
-                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
             )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
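A plausible setting where the explicit device move matters, assuming a checkpoint loaded with a sharded device map (the model id and loading flags here are illustrative and not part of the commit): accelerate can place the embedding layer and other submodules on different devices, so masks derived from input_ids have to follow inputs_embeds.

# Illustrative loading only; any public LLaVA-family checkpoint would do.
from transformers import AutoProcessor, LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")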