qwen2.5vl: fix bugs when using flash2+bf16 or num_return_sequences>1 (#36083)

* qwen2.5vl: fix bugs when using flash2+bf16 or num_return_sequences>1

* fix

* fix

* fix

* fix

* add tests

* fix test bugs

* fix

* fix failed tests

* fix
This commit is contained in:
gewenbin0992 2025-02-13 18:35:28 +08:00 committed by GitHub
parent d419862889
commit 6a1ab634b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 299 additions and 7 deletions

View File

@ -2824,8 +2824,12 @@ class GenerationMixin:
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
# input_ids is required for expanding visual inputs in qwen2vl
_, model_kwargs = self._expand_inputs_for_generation(
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
input_ids=input_ids,
expand_size=top_k,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
past_key_values = model_kwargs.get("past_key_values")

View File

@ -26,7 +26,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -452,7 +452,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
The final hidden states of the model.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
The temporal, height and width of feature shape of each image in LLM.
@ -1459,7 +1459,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
def __init__(self, config):
super().__init__(config)
@ -1933,5 +1933,127 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
)
return model_inputs
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Returns:
image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
return image_nums, video_nums
def _expand_inputs_for_generation(
self,
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
# Overwritten -- Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
# pixel_values.shape[0] is sum(seqlen_images for samples)
# image_grid_thw.shape[0] is sum(num_images for samples)
if expand_size == 1:
return input_ids, model_kwargs
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
return result
for key in dict_to_expand:
if key == "pixel_values":
# split images into samples
samples = torch.split(image_grid_thw, list(image_nums))
# compute the sequence length of images for each sample
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "image_grid_thw":
# get the num of images for each sample
lengths = list(image_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "pixel_values_videos":
samples = torch.split(video_grid_thw, list(video_nums))
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "video_grid_thw":
lengths = list(video_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "second_per_grid_ts":
if not isinstance(dict_to_expand[key], list):
raise TypeError(
f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
)
tensor = torch.tensor(dict_to_expand[key])
lengths = list(video_nums)
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
dict_to_expand[key] = tensor.tolist()
return dict_to_expand
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if (
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and key not in visual_keys
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
model_kwargs = _expand_dict_for_generation(model_kwargs)
if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
return input_ids, model_kwargs
__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"]

View File

@ -312,7 +312,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
The final hidden states of the model.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
The temporal, height and width of feature shape of each image in LLM.
@ -382,7 +382,7 @@ class Qwen2_5_VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
def __init__(self, config):
super().__init__(config)

View File

@ -21,7 +21,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -1796,5 +1796,127 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
)
return model_inputs
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Returns:
image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
vision_start_mask = input_ids == vision_start_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
return image_nums, video_nums
def _expand_inputs_for_generation(
self,
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
# Overwritten -- Support for expanding tensors without a batch size dimension
# e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
# pixel_values.shape[0] is sum(seqlen_images for samples)
# image_grid_thw.shape[0] is sum(num_images for samples)
if expand_size == 1:
return input_ids, model_kwargs
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
return result
for key in dict_to_expand:
if key == "pixel_values":
# split images into samples
samples = torch.split(image_grid_thw, list(image_nums))
# compute the sequence length of images for each sample
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "image_grid_thw":
# get the num of images for each sample
lengths = list(image_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "pixel_values_videos":
samples = torch.split(video_grid_thw, list(video_nums))
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "video_grid_thw":
lengths = list(video_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "second_per_grid_ts":
if not isinstance(dict_to_expand[key], list):
raise TypeError(
f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
)
tensor = torch.tensor(dict_to_expand[key])
lengths = list(video_nums)
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
dict_to_expand[key] = tensor.tolist()
return dict_to_expand
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if (
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and key not in visual_keys
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
# input_ids is required for expanding visual inputs
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
if input_ids is not None and input_ids.numel() != 0:
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
model_kwargs = _expand_dict_for_generation(model_kwargs)
if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
return input_ids, model_kwargs
__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"]

View File

@ -171,7 +171,9 @@ class Qwen2_5_VLVisionText2TextModelTester:
input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.video_token_id] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
input_ids[:, self.num_image_tokens] = self.image_token_id
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
labels = torch.zeros(
(self.batch_size, self.seq_length),
dtype=torch.long,
@ -426,6 +428,26 @@ class Qwen2_5_VLIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@slow
def test_small_model_integration_test_expand(self):
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
inputs = self.processor(text=[text], images=[self.image], return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, num_return_sequences=3)
EXPECTED_DECODED_TEXT = [
'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
def test_small_model_integration_test_batch_wo_image(self):
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(

View File

@ -167,7 +167,9 @@ class Qwen2VLVisionText2TextModelTester:
input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.video_token_id] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
input_ids[:, self.num_image_tokens] = self.image_token_id
input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id
labels = torch.zeros(
(self.batch_size, self.seq_length),
dtype=torch.long,
@ -435,6 +437,26 @@ class Qwen2VLIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT,
)
@slow
def test_small_model_integration_test_expand(self):
model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)
text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
inputs = self.processor(text=[text], images=[self.image], return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, num_return_sequences=3)
EXPECTED_DECODED_TEXT = [
'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices',
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
def test_small_model_integration_test_batch_wo_image(self):
model = Qwen2VLForConditionalGeneration.from_pretrained(