mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)

[compile] re-enable for Qwen-VL models (#38127)

* compile qwen models
* delete TODO comment
* fix embeds test
* fix assisted decoding
* add comments

This commit is contained in:
parent 4542086db7
commit a21f11fca2
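The change flips `_supports_static_cache` back to `True` and makes the pre-fill logic traceable, so Qwen2-VL and Qwen2.5-VL can again generate with a static KV cache under torch.compile. A minimal sketch of what that enables, assuming a standard Qwen2.5-VL checkpoint (the model id and prompt below are illustrative, not taken from this commit):

import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # illustrative checkpoint
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# A text-only prompt keeps the sketch short; image and video inputs go through the same forward path.
inputs = processor(
    text=["Give me a short introduction to vision-language models."], return_tensors="pt"
).to(model.device)

# A static cache pre-allocates the KV tensors to a fixed length, so every decode step sees
# stable shapes and torch.compile can capture it without recompiling.
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
print(processor.batch_decode(out, skip_special_tokens=True)[0])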
@@ -40,7 +40,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
 from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
 from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
@@ -358,7 +358,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
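`_supports_static_cache` is the flag `generate()` consults before it will build a static KV cache for a model, and the old TODO about `cache_positions` is what the rest of this diff resolves. With the flag set, the decode step can also be compiled explicitly; a short sketch reusing `model` and `inputs` from the example above (the compile mode is an assumption, not something this commit configures):

import torch

# Route every forward call through the compiled graph; with fixed cache shapes the
# graph is reused across decode steps instead of being recompiled.
model.forward = torch.compile(model.forward, mode="reduce-overhead")
out = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")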
@@ -1659,9 +1659,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
             inputs_embeds = self.get_input_embeddings()(input_ids)
         if pixel_values is not None:
             image_embeds = self.get_image_features(pixel_values, image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            n_image_tokens = (input_ids == self.config.image_token_id).sum()
             n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
+            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
@@ -1676,9 +1676,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
 
         if pixel_values_videos is not None:
             video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+            n_video_tokens = (input_ids == self.config.video_token_id).sum()
             n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
+            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
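Two details make these sanity checks compile-friendly: dropping `.item()` keeps the token count on the device (calling `.item()` forces a graph break), and the `raise` is skipped while Dynamo is tracing because a data-dependent exception cannot be compiled. The same pattern in isolation (a standalone sketch; the helper name is made up for illustration):

import torch
from transformers.utils import is_torchdynamo_compiling


def check_image_tokens(input_ids: torch.Tensor, image_embeds: torch.Tensor, image_token_id: int) -> None:
    # Keep the count as a tensor: .item() would sync to the host and break the compiled graph.
    n_image_tokens = (input_ids == image_token_id).sum()
    n_image_features = image_embeds.shape[0]
    # A raise driven by tensor values cannot be traced, so the guard only runs in eager mode.
    if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        )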
@@ -1694,20 +1694,32 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (
+        if position_ids is None:
+            attention_mask_2d = attention_mask
+            if attention_mask is not None and attention_mask.ndim == 4:
+                attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
+                attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
+                attention_mask_2d = (1.0 - attention_mask_2d).int()
+
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+            prefill_compiled_stage = is_torchdynamo_compiling() and (
+                (input_ids is not None and input_ids.shape[1] != 1)
+                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+            )
+            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                 (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
                 or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            ):
+            )
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
-                    second_per_grid_ts,
-                    attention_mask,
+                    second_per_grid_ts=second_per_grid_ts,
+                    attention_mask=attention_mask_2d,
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
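The rewritten branch has to answer "are we in pre-fill?" without reading tensor values when compiling: checks such as `cache_position[0] == 0` depend on data, which Dynamo cannot branch on. Under compilation the input length is used instead; as the new comment notes, a sequence longer than one token can only be the prompt because compiled models do not run assisted decoding. The predicate in isolation (a sketch with a made-up helper name, not the model's exact code):

from typing import Optional

import torch
from transformers.utils import is_torchdynamo_compiling


def in_prefill_stage(
    input_ids: Optional[torch.Tensor],
    inputs_embeds: Optional[torch.Tensor],
    cache_position: Optional[torch.Tensor],
    past_key_values,
) -> bool:
    if is_torchdynamo_compiling():
        # Shape-only check: more than one token in the step must be the prompt.
        return (input_ids is not None and input_ids.shape[1] != 1) or (
            inputs_embeds is not None and inputs_embeds.shape[1] != 1
        )
    # Eager mode may inspect values: cache position 0, or an empty/absent KV cache.
    return bool(
        (cache_position is not None and cache_position[0] == 0)
        or (past_key_values is None or past_key_values.get_seq_length() == 0)
    )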
@@ -1747,6 +1759,61 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
         )
         return output if return_dict else output.to_tuple()
 
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
 
 @dataclass
 class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
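The method added above is not new code: it is moved from `Qwen2_5_VLForConditionalGeneration` (removed in the next hunk) onto the base model, which is the class that now has to build the full-length mask a static cache requires. A toy call showing the shapes involved (values are illustrative, the underscore-prefixed method is called directly only for demonstration, and the import assumes `Qwen2_5_VLModel` is exposed at the top level as the `__all__` below suggests):

import torch
from transformers import Qwen2_5_VLModel

# Two sequences of 3 query tokens; the static cache is padded out to 8 slots and the
# second sequence has its first position masked out (left padding).
attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
causal_mask = Qwen2_5_VLModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=attention_mask,
    sequence_length=3,
    target_length=8,  # length of the pre-allocated static cache
    dtype=torch.float32,
    cache_position=torch.arange(3),
    batch_size=2,
)
print(causal_mask.shape)  # torch.Size([2, 1, 3, 8]); 0 where attending is allowed, finfo.min elsewhere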
@@ -2108,60 +2175,5 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
 
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-                    causal_mask.device
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
 
 __all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]
@@ -50,7 +50,7 @@ from ...image_utils import ImageInput
 from ...modeling_flash_attention_utils import is_flash_attn_available
 from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import logging
+from ...utils import is_torchdynamo_compiling, logging
 from ...video_utils import VideoInput
 
 
@@ -647,9 +647,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
             inputs_embeds = self.get_input_embeddings()(input_ids)
         if pixel_values is not None:
             image_embeds = self.get_image_features(pixel_values, image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            n_image_tokens = (input_ids == self.config.image_token_id).sum()
             n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
+            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
@@ -664,9 +664,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
 
         if pixel_values_videos is not None:
             video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+            n_video_tokens = (input_ids == self.config.video_token_id).sum()
             n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
+            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
@@ -682,20 +682,32 @@ class Qwen2_5_VLModel(Qwen2VLModel):
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (
+        if position_ids is None:
+            attention_mask_2d = attention_mask
+            if attention_mask is not None and attention_mask.ndim == 4:
+                attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
+                attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
+                attention_mask_2d = (1.0 - attention_mask_2d).int()
+
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+            prefill_compiled_stage = is_torchdynamo_compiling() and (
+                (input_ids is not None and input_ids.shape[1] != 1)
+                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+            )
+            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                 (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
                 or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            ):
+            )
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids,
                     image_grid_thw,
                     video_grid_thw,
-                    second_per_grid_ts,
-                    attention_mask,
+                    second_per_grid_ts=second_per_grid_ts,
+                    attention_mask=attention_mask_2d,
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
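When a ready-made 4D additive mask arrives (as it does during assisted decoding), `get_rope_index` still needs the plain 2D padding mask, so the code above reads it back off the diagonal: the additive mask stores 0 for visible positions and `finfo.min` for masked ones, and a token's own diagonal entry says whether that token is padding. A small round-trip check of the trick (standalone sketch):

import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min

# A 2D padding mask: batch of 2, length 4, second sequence left-padded by one position.
mask_2d = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])

# Build the equivalent additive 4D mask: 0 where attending is allowed, finfo.min where it is not.
causal = torch.zeros(2, 1, 4, 4, dtype=dtype)
causal += torch.triu(torch.full((4, 4), min_val, dtype=dtype), diagonal=1)
causal = causal.masked_fill((mask_2d == 0)[:, None, None, :], min_val)

# Recover the 2D mask the same way the model code does.
diag = torch.diagonal(causal[:, 0], dim1=1, dim2=2)  # (batch, seq_len), entries are 0 or finfo.min
recovered = (1.0 - diag / min_val).int()
print(torch.equal(recovered, mask_2d.int()))  # True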
@@ -924,7 +924,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
@@ -1616,16 +1616,28 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (
+        if position_ids is None:
+            attention_mask_2d = attention_mask
+            if attention_mask is not None and attention_mask.ndim == 4:
+                attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
+                attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
+                attention_mask_2d = (1.0 - attention_mask_2d).int()
+
+            # Calculate RoPE index once per generation in the pre-fill stage only.
+            # When compiling, we can't check tensor values thus we check only input length
+            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
+            # models currently cannot do assisted decoding
+            prefill_compiled_stage = is_torchdynamo_compiling() and (
+                (input_ids is not None and input_ids.shape[1] != 1)
+                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
+            )
+            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                 (cache_position is not None and cache_position[0] == 0)
-                or self.rope_deltas is None
                 or (past_key_values is None or past_key_values.get_seq_length() == 0)
-            ):
+            )
+            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                 position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask_2d
                 )
                 self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
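The cached `rope_deltas` is the per-sequence offset between the multimodal 3D RoPE index computed at pre-fill and a plain text position counter; the context comment "then use the prev pre-calculated rope-deltas" refers to the decode branch outside this hunk, which simply shifts the cache position by that offset. Conceptually (a sketch of the idea, not the exact model code):

import torch


def decode_step_position_ids(cache_position: torch.Tensor, rope_deltas: torch.Tensor, batch_size: int) -> torch.Tensor:
    # cache_position: (seq_len,) slots being written in the KV cache for this step
    # rope_deltas:    (batch, 1) offset computed once by get_rope_index during pre-fill
    positions = cache_position.view(1, -1) + rope_deltas.to(cache_position.device)  # (batch, seq_len)
    # Qwen2-VL RoPE is 3D (temporal, height, width); for text-only decoding the three
    # components coincide, so the same positions are repeated along a new leading axis.
    return positions.unsqueeze(0).expand(3, batch_size, -1)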
@@ -1662,6 +1674,62 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         )
         return output if return_dict else output.to_tuple()
 
+    @staticmethod
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
 
 class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
     _checkpoint_conversion_mapping = {
@@ -1974,61 +2042,5 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
 
         return input_ids, model_kwargs
 
-    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-                    causal_mask.device
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
 
 __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]
@@ -346,10 +346,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
     def test_model_parallelism(self):
         pass
 
-    @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models")
-    def test_sdpa_can_compile_dynamic(self):
-        pass
-
     @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models")
     def test_sdpa_can_dispatch_on_flash(self):
         pass
@@ -368,10 +364,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass
 
-    @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
-    def test_generate_compile_fullgraph(self):
-        pass
-
     @is_flaky()  # TODO (joao/raushan): Investigate why this test is flaky on this model
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         super().test_prompt_lookup_decoding_matches_greedy_search()
@@ -300,10 +300,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
     def test_model_parallelism(self):
         pass
 
-    @unittest.skip(reason="Compile not yet supported because in Qwen2VL models")
-    def test_sdpa_can_compile_dynamic(self):
-        pass
-
     @unittest.skip(reason="Compile not yet supported because in Qwen2VL models")
     def test_sdpa_can_dispatch_on_flash(self):
         pass