qwen2.5vl: fix bugs when using flash2+bf16 or num_return_sequences>1 (#36083)
* qwen2.5vl: fix bugs when using flash2+bf16 or num_return_sequences>1
* fix
* fix
* fix
* fix
* add tests
* fix test bugs
* fix
* fix failed tests
* fix
This commit is contained in: parent d419862889, commit 6a1ab634b6
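For context on the `num_return_sequences > 1` part of the fix: Qwen2-VL's visual inputs (`pixel_values`, `image_grid_thw`, `pixel_values_videos`, `video_grid_thw`, `second_per_grid_ts`) stack items across the whole batch instead of carrying a batch dimension, so the generic input expansion in `generate` breaks them. A minimal sketch of the call path exercised by the new integration tests below, assuming the public Qwen/Qwen2.5-VL-7B-Instruct checkpoint, flash-attention-2 installed, and a placeholder local image `demo.jpg`:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Sketch only: the checkpoint name matches the tests in this PR; the image path and
# dtype/attention settings are illustrative assumptions, not part of the diff.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,               # the bf16 half of "flash2+bf16"
    attn_implementation="flash_attention_2",  # the flash2 half
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": "What kind of dog is this?"}],
    }
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[Image.open("demo.jpg")], return_tensors="pt").to(model.device)

# num_return_sequences > 1 makes generate() expand every model input 3x; the new
# per-model _expand_inputs_for_generation override handles the visual tensors.
output = model.generate(**inputs, max_new_tokens=30, num_return_sequences=3)
print(processor.batch_decode(output, skip_special_tokens=True))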
@@ -2824,8 +2824,12 @@ class GenerationMixin:
 
                 if not sequential:
                     # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+                    # input_ids is required for expanding visual inputs in qwen2vl
                     _, model_kwargs = self._expand_inputs_for_generation(
-                        expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                        input_ids=input_ids,
+                        expand_size=top_k,
+                        is_encoder_decoder=self.config.is_encoder_decoder,
+                        **model_kwargs,
                     )
 
                 past_key_values = model_kwargs.get("past_key_values")
@@ -26,7 +26,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -452,7 +452,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
+            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                 The final hidden states of the model.
             grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                 The temporal, height and width of feature shape of each image in LLM.
@@ -1459,7 +1459,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
 class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     config_class = Qwen2_5_VLConfig
-    _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+    _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1933,5 +1933,127 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
         )
         return model_inputs
 
+    def _get_image_nums_and_video_nums(
+        self,
+        input_ids: Optional[torch.LongTensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+        Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        vision_start_mask = input_ids == vision_start_token_id
+        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+        image_mask = input_ids == image_token_id
+        video_mask = input_ids == video_token_id
+        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+        return image_nums, video_nums
+
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+        # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "second_per_grid_ts":
+                    if not isinstance(dict_to_expand[key], list):
+                        raise TypeError(
+                            f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
+                        )
+                    tensor = torch.tensor(dict_to_expand[key])
+                    lengths = list(video_nums)
+                    tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
+                    dict_to_expand[key] = tensor.tolist()
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
+                    key != "cache_position"
+                    and dict_to_expand[key] is not None
+                    and isinstance(dict_to_expand[key], torch.Tensor)
+                    and key not in visual_keys
+                ):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
+        # input_ids is required for expanding visual inputs
+        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
+        if input_ids is not None and input_ids.numel() != 0:
+            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+        if is_encoder_decoder:
+            if model_kwargs.get("encoder_outputs") is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+        return input_ids, model_kwargs
+
 
 __all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"]
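For intuition, a self-contained sketch of the counting trick used by `_get_image_nums_and_video_nums` above; the token ids are toy values, the real ones come from the model config (`image_token_id`, `video_token_id`, `vision_start_token_id`). Each image or video is counted where a vision-start token is immediately followed by the corresponding placeholder token:

import torch

# Toy ids for illustration only.
VISION_START, IMAGE, VIDEO, TEXT = 90, 91, 92, 1

# Sample 0 contains one image; sample 1 contains one image and one video.
input_ids = torch.tensor(
    [
        [TEXT, VISION_START, IMAGE, IMAGE, TEXT, TEXT, TEXT, TEXT],
        [TEXT, VISION_START, IMAGE, TEXT, VISION_START, VIDEO, VIDEO, TEXT],
    ]
)

# Same logic as the method above: mark the position right after each vision-start
# token, then check whether that position holds an image or a video token.
vision_start_mask = input_ids == VISION_START
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & (input_ids == IMAGE), dim=1)
video_nums = torch.sum(vision_first_mask & (input_ids == VIDEO), dim=1)
print(image_nums.tolist(), video_nums.tolist())  # [1, 1] [0, 1]

These per-sample counts are what let `_expand_inputs_for_generation` split `image_grid_thw` / `pixel_values` into per-sample chunks before repeating them, instead of assuming a batch dimension that is not there.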
@@ -312,7 +312,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
+            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                 The final hidden states of the model.
             grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                 The temporal, height and width of feature shape of each image in LLM.
@@ -382,7 +382,7 @@ class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
 
 class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
     config_class = Qwen2_5_VLConfig
-    _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+    _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -21,7 +21,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -1796,5 +1796,127 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         )
         return model_inputs
 
+    def _get_image_nums_and_video_nums(
+        self,
+        input_ids: Optional[torch.LongTensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+        Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        vision_start_mask = input_ids == vision_start_token_id
+        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+        image_mask = input_ids == image_token_id
+        video_mask = input_ids == video_token_id
+        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+        return image_nums, video_nums
+
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+        # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "second_per_grid_ts":
+                    if not isinstance(dict_to_expand[key], list):
+                        raise TypeError(
+                            f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
+                        )
+                    tensor = torch.tensor(dict_to_expand[key])
+                    lengths = list(video_nums)
+                    tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
+                    dict_to_expand[key] = tensor.tolist()
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
+                    key != "cache_position"
+                    and dict_to_expand[key] is not None
+                    and isinstance(dict_to_expand[key], torch.Tensor)
+                    and key not in visual_keys
+                ):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
+        # input_ids is required for expanding visual inputs
+        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
+        if input_ids is not None and input_ids.numel() != 0:
+            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+        if is_encoder_decoder:
+            if model_kwargs.get("encoder_outputs") is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+        return input_ids, model_kwargs
+
 
 __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"]
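The block above mirrors the Qwen2.5-VL implementation for Qwen2-VL. To complement the previous sketch, here is a small illustration (grid values made up) of what the nested `_repeat_interleave_samples` helper does: each sample's rows are repeated as a contiguous block, which keeps the per-sample grouping aligned with an `input_ids` that was expanded with `repeat_interleave`:

import torch

# Suppose sample 0 has one image and sample 1 has two images; image_grid_thw then
# has three rows stacked across the batch (values are illustrative).
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 2], [1, 4, 4]])
image_nums = [1, 2]   # per-sample counts, as returned by _get_image_nums_and_video_nums
expand_size = 3       # e.g. num_return_sequences=3

def _repeat_interleave_samples(x, lengths, repeat_times):
    # Split rows into per-sample chunks, repeat each chunk as a block, re-concatenate.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

expanded = _repeat_interleave_samples(image_grid_thw, lengths=image_nums, repeat_times=expand_size)
print(expanded.shape)  # torch.Size([9, 3]): 3 copies of sample 0's row, then 3 copies of sample 1's two rows

A plain `repeat_interleave(expand_size, dim=0)`, which the generic expansion applies to batched tensors, would instead repeat every row individually and scramble the per-sample grouping of these stacked visual tensors.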
@@ -171,7 +171,9 @@ class Qwen2_5_VLVisionText2TextModelTester:
         input_ids[:, -1] = self.pad_token_id
         input_ids[input_ids == self.video_token_id] = self.pad_token_id
         input_ids[input_ids == self.image_token_id] = self.pad_token_id
+        input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
         input_ids[:, self.num_image_tokens] = self.image_token_id
+        input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
         labels = torch.zeros(
             (self.batch_size, self.seq_length),
             dtype=torch.long,
@@ -426,6 +428,26 @@ class Qwen2_5_VLIntegrationTest(unittest.TestCase):
             EXPECTED_DECODED_TEXT,
         )
 
+    @slow
+    def test_small_model_integration_test_expand(self):
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+        )
+        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
+        inputs = self.processor(text=[text], images=[self.image], return_tensors="pt").to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, num_return_sequences=3)
+
+        EXPECTED_DECODED_TEXT = [
+            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
+            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
+            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
     @slow
     def test_small_model_integration_test_batch_wo_image(self):
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -167,7 +167,9 @@ class Qwen2VLVisionText2TextModelTester:
         input_ids[:, -1] = self.pad_token_id
         input_ids[input_ids == self.video_token_id] = self.pad_token_id
         input_ids[input_ids == self.image_token_id] = self.pad_token_id
+        input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
         input_ids[:, self.num_image_tokens] = self.image_token_id
+        input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
         labels = torch.zeros(
             (self.batch_size, self.seq_length),
             dtype=torch.long,
@@ -435,6 +437,26 @@ class Qwen2VLIntegrationTest(unittest.TestCase):
             EXPECTED_DECODED_TEXT,
         )
 
+    @slow
+    def test_small_model_integration_test_expand(self):
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
+        )
+        text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
+        inputs = self.processor(text=[text], images=[self.image], return_tensors="pt").to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, num_return_sequences=3)
+
+        EXPECTED_DECODED_TEXT = [
+            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
+            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
+            'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
     @slow
     def test_small_model_integration_test_batch_wo_image(self):
         model = Qwen2VLForConditionalGeneration.from_pretrained(