[compile] re-enable for Qwen-VL models (#38127)

* compile qwen models

* delete TODO comment

* fix embeds test

* fix assisted decoding

* add comments
This commit is contained in:
Raushan Turganbay 2025-05-21 11:50:39 +02:00 committed by GitHub
parent 4542086db7
commit a21f11fca2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 182 additions and 158 deletions

View File

@ -40,7 +40,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
@ -358,7 +358,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` _supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.get_text_config().initializer_range std = self.config.get_text_config().initializer_range
@ -1659,9 +1659,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
@ -1676,9 +1676,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
@ -1694,20 +1694,32 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device) attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme if position_ids is None:
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): attention_mask_2d = attention_mask
# calculate RoPE index once per generation in the pre-fill stage only if attention_mask is not None and attention_mask.ndim == 4:
if ( attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
attention_mask_2d = (1.0 - attention_mask_2d).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0) (cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) or (past_key_values is None or past_key_values.get_seq_length() == 0)
): )
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index( position_ids, rope_deltas = self.get_rope_index(
input_ids, input_ids,
image_grid_thw, image_grid_thw,
video_grid_thw, video_grid_thw,
second_per_grid_ts, second_per_grid_ts=second_per_grid_ts,
attention_mask, attention_mask=attention_mask_2d,
) )
self.rope_deltas = rope_deltas self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids # then use the prev pre-calculated rope-deltas to get the correct position ids
@ -1747,6 +1759,61 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
) )
return output if return_dict else output.to_tuple() return output if return_dict else output.to_tuple()
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@dataclass @dataclass
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
@ -2108,60 +2175,5 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return input_ids, model_kwargs return input_ids, model_kwargs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"] __all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]

View File

@ -50,7 +50,7 @@ from ...image_utils import ImageInput
from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_flash_attention_utils import is_flash_attn_available
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging from ...utils import is_torchdynamo_compiling, logging
from ...video_utils import VideoInput from ...video_utils import VideoInput
@ -647,9 +647,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None: if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0] n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features: if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError( raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
) )
@ -664,9 +664,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if pixel_values_videos is not None: if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_embeds.shape[0] n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features: if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError( raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
) )
@ -682,20 +682,32 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device) attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme if position_ids is None:
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): attention_mask_2d = attention_mask
# calculate RoPE index once per generation in the pre-fill stage only if attention_mask is not None and attention_mask.ndim == 4:
if ( attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
attention_mask_2d = (1.0 - attention_mask_2d).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0) (cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) or (past_key_values is None or past_key_values.get_seq_length() == 0)
): )
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index( position_ids, rope_deltas = self.get_rope_index(
input_ids, input_ids,
image_grid_thw, image_grid_thw,
video_grid_thw, video_grid_thw,
second_per_grid_ts, second_per_grid_ts=second_per_grid_ts,
attention_mask, attention_mask=attention_mask_2d,
) )
self.rope_deltas = rope_deltas self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids # then use the prev pre-calculated rope-deltas to get the correct position ids

View File

@ -924,7 +924,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True _supports_flash_attn_2 = True
_supports_sdpa = True _supports_sdpa = True
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` _supports_static_cache = True
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.get_text_config().initializer_range std = self.config.get_text_config().initializer_range
@ -1616,16 +1616,28 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device) attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme if position_ids is None:
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): attention_mask_2d = attention_mask
# calculate RoPE index once per generation in the pre-fill stage only if attention_mask is not None and attention_mask.ndim == 4:
if ( attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
attention_mask_2d = (1.0 - attention_mask_2d).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0) (cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) or (past_key_values is None or past_key_values.get_seq_length() == 0)
): )
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index( position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask input_ids, image_grid_thw, video_grid_thw, attention_mask_2d
) )
self.rope_deltas = rope_deltas self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids # then use the prev pre-calculated rope-deltas to get the correct position ids
@ -1662,6 +1674,62 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
) )
return output if return_dict else output.to_tuple() return output if return_dict else output.to_tuple()
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = { _checkpoint_conversion_mapping = {
@ -1974,61 +2042,5 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return input_ids, model_kwargs return input_ids, model_kwargs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"] __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]

View File

@ -346,10 +346,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_model_parallelism(self): def test_model_parallelism(self):
pass pass
@unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models") @unittest.skip(reason="Compile not yet supported because in Qwen2_5_VL models")
def test_sdpa_can_dispatch_on_flash(self): def test_sdpa_can_dispatch_on_flash(self):
pass pass
@ -368,10 +364,6 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
def test_generate_compile_fullgraph(self):
pass
@is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model @is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model
def test_prompt_lookup_decoding_matches_greedy_search(self): def test_prompt_lookup_decoding_matches_greedy_search(self):
super().test_prompt_lookup_decoding_matches_greedy_search() super().test_prompt_lookup_decoding_matches_greedy_search()

View File

@ -300,10 +300,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_model_parallelism(self): def test_model_parallelism(self):
pass pass
@unittest.skip(reason="Compile not yet supported because in Qwen2VL models")
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip(reason="Compile not yet supported because in Qwen2VL models") @unittest.skip(reason="Compile not yet supported because in Qwen2VL models")
def test_sdpa_can_dispatch_on_flash(self): def test_sdpa_can_dispatch_on_flash(self):
pass pass