[VLMs] fix flash-attention tests (#37603)

* fix one test

* fa2 ln test

* remove keys from config recursively

* fix

* fixup
Raushan Turganbay 2025-04-24 11:48:11 +02:00 committed by GitHub
parent 02baa61fab
commit 1cfcbfcab8
17 changed files with 52 additions and 83 deletions

View File

@@ -843,29 +843,16 @@ class PretrainedConfig(PushToHubMixin):
):
serializable_config_dict[key] = value
self._remove_keys_not_serialized(serializable_config_dict)
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(serializable_config_dict)
if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_pp_plan"]
if "_name_or_path" in serializable_config_dict:
del serializable_config_dict["_name_or_path"]
return serializable_config_dict
def to_dict(self) -> dict[str, Any]:
@@ -878,18 +865,6 @@ class PretrainedConfig(PushToHubMixin):
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in output:
del output["base_model_pp_plan"]
# Transformers version when serializing the model
output["transformers_version"] = __version__
@@ -902,16 +877,14 @@ class PretrainedConfig(PushToHubMixin):
output[key] = value
self._remove_keys_not_serialized(output)
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = output.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(output)
return output
@@ -1011,6 +984,33 @@ class PretrainedConfig(PushToHubMixin):
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
"""
Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
Runs recursive check on the dict, to remove from all sub configs.
"""
if hasattr(self, "quantization_config"):
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = d.pop("_pre_quantization_dtype", None)
if "_auto_class" in d:
del d["_auto_class"]
if "_commit_hash" in d:
del d["_commit_hash"]
if "_attn_implementation_internal" in d:
del d["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in d:
del d["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in d:
del d["base_model_pp_plan"]
if "_name_or_path" in d:
del d["_name_or_path"]
for value in d.values():
if isinstance(value, dict):
self._remove_keys_not_serialized(value)
@classmethod
def register_for_auto_class(cls, auto_class="AutoConfig"):
"""

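The hunks above fold the per-key `del` blocks from `to_diff_dict()` and `to_dict()` into a single `_remove_keys_not_serialized()` helper that also walks nested sub-config dicts, which is what composite (VLM) configs needed: previously only the top-level dict was cleaned, so nested text/vision configs kept keys such as `_attn_implementation_internal`. A minimal standalone sketch of the recursion; only the key names are taken from the diff, the helper name and example dict are illustrative:

```python
from typing import Any

# Illustrative subset of keys that must not be serialized (names taken from the hunk above).
NON_SERIALIZED_KEYS = (
    "_auto_class",
    "_commit_hash",
    "_attn_implementation_internal",
    "base_model_tp_plan",
    "base_model_pp_plan",
    "_name_or_path",
)


def remove_keys_not_serialized(d: dict[str, Any]) -> None:
    """Drop non-serializable keys in place, descending into nested sub-config dicts."""
    for key in NON_SERIALIZED_KEYS:
        d.pop(key, None)
    for value in d.values():
        if isinstance(value, dict):
            remove_keys_not_serialized(value)


config_dict = {
    "model_type": "llava",
    "_attn_implementation_internal": "flash_attention_2",
    "text_config": {"model_type": "llama", "_attn_implementation_internal": "flash_attention_2"},
}
remove_keys_not_serialized(config_dict)
assert "_attn_implementation_internal" not in config_dict["text_config"]
```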
View File

@@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth
config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
def _assign_original_dtype(module):
for child in module.children():
if isinstance(child, PreTrainedModel):
child.config._pre_quantization_dtype = original_dtype
_assign_original_dtype(child)
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)
# Prepare the full device map
if device_map is not None:
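The quantization hunk above applies the same recursive idea to the model tree: composite checkpoints nest `PreTrainedModel` submodules that carry their own configs, so `_pre_quantization_dtype` has to be recorded on each of them rather than only on the top-level config. A minimal sketch of such a traversal, with a hypothetical `config`-carrying child standing in for the real `isinstance(child, PreTrainedModel)` check:

```python
import torch
from torch import nn


class TinySubModel(nn.Module):
    """Stand-in for a nested PreTrainedModel that owns its own config object."""

    def __init__(self):
        super().__init__()
        self.config = type("Cfg", (), {})()  # hypothetical bare config holder
        self.proj = nn.Linear(4, 4)


def assign_original_dtype(module: nn.Module, original_dtype: torch.dtype) -> None:
    # Walk the module tree and record the dtype on every child that carries a config.
    for child in module.children():
        if hasattr(child, "config"):  # the real code checks isinstance(child, PreTrainedModel)
            child.config._pre_quantization_dtype = original_dtype
        assign_original_dtype(child, original_dtype)


model = nn.Sequential(TinySubModel(), nn.ReLU())
assign_original_dtype(model, torch.float16)
assert model[0].config._pre_quantization_dtype == torch.float16
```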

View File

@@ -125,6 +125,9 @@ class InternVLVisionAttention(nn.Module):
proj_dropout = config.projection_dropout
qk_norm = config.use_qk_norm
# Needed for flash attention
self.is_causal = False
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
@@ -134,9 +137,6 @@ class InternVLVisionAttention(nn.Module):
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
# Needed for flash attention
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,

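For context on the `is_causal` attribute moved above: the pluggable attention backends read the causality flag from the attention module instead of hard-coding it, and vision towers attend bidirectionally over image patches, hence `False`. A rough illustration of how a backend can consume the flag; this is a sketch of the pattern, not the exact transformers dispatch:

```python
import torch
import torch.nn.functional as F


def sdpa_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # Only apply a causal mask when the module asks for one and no explicit mask is given.
    is_causal = getattr(module, "is_causal", True) and attention_mask is None
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )


# Toy call: (batch, num_heads, seq_len, head_dim) tensors and a module with is_causal=False.
vision_attn = type("VisionAttn", (), {"is_causal": False})()
q = k = v = torch.randn(1, 2, 5, 8)
out = sdpa_attention_forward(vision_attn, q, k, v)
```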
View File

@@ -344,6 +344,7 @@ class JanusVisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout
proj_dropout = config.projection_dropout
qk_norm = config.use_qk_norm
self.is_causal = False
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
self.num_key_value_groups = 1
@@ -398,7 +399,7 @@ class JanusVisionAttention(nn.Module):
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=False,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

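A side note on the `num_key_value_groups = 1` line kept as context above: the shared eager attention path repeats K/V heads per query-head group (grouped-query attention), and a group size of 1 turns that repeat into a no-op, so the plain multi-head vision attention can reuse the same helper. A sketch of that repetition, mirroring the common `repeat_kv` pattern used across transformers models:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq, head_dim) to (batch, num_kv_heads * n_rep, seq, head_dim)."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:  # num_key_value_groups == 1: nothing to repeat, plain MHA
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


kv = torch.randn(1, 4, 10, 64)
assert repeat_kv(kv, 1) is kv                      # group size 1 is a pass-through
assert repeat_kv(kv, 2).shape == (1, 8, 10, 64)
```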
View File

@@ -509,6 +509,7 @@ class JanusVisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout
proj_dropout = config.projection_dropout
qk_norm = config.use_qk_norm
self.is_causal = False
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
self.num_key_value_groups = 1
@@ -563,7 +564,7 @@ class JanusVisionAttention(nn.Module):
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale,
is_causal=False,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

View File

@@ -316,10 +316,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_sdpa_can_compile_dynamic(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
# todo: yoni - fix or improve the test
@unittest.skip("Difference is slightly higher than the threshold")
def test_batching_equivalence(self):

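From here down, the test files simply drop their `test_flash_attn_2_fp32_ln` skips: the shared test (the "fa2 ln test" bullet in the commit message) now passes for these VLMs. Roughly, that common test loads the model in half precision with flash attention, force-casts the normalization weights back to fp32, and checks that a forward pass still runs; a schematic sketch under those assumptions, not the test's actual code:

```python
import torch


def check_fa2_with_fp32_layer_norms(model_class, checkpoint, dummy_inputs):
    # Schematic check: FA2 kernels want fp16/bf16 activations, but fp32 norm weights must not break the forward.
    model = model_class.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    ).to("cuda")
    for name, param in model.named_parameters():
        if "norm" in name:  # upcast only normalization weights
            param.data = param.data.float()
    with torch.no_grad():
        model(**dummy_inputs)  # should complete without dtype errors
```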
View File

@@ -222,8 +222,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
@unittest.skip(
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
)
def test_past_key_values_format(self):
pass

View File

@@ -407,10 +407,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
@require_torch_sdpa
@slow

View File

@@ -367,10 +367,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
@require_torch_sdpa
@slow

View File

@@ -302,10 +302,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)

View File

@@ -331,10 +331,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)

View File

@@ -302,10 +302,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)

View File

@@ -334,10 +334,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)

View File

@@ -331,10 +331,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)

View File

@@ -378,10 +378,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
def test_generate_methods_with_logits_to_keep(self):
super().test_generate_methods_with_logits_to_keep()
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip
def test_training_gradient_checkpointing(self):
pass

View File

@@ -225,10 +225,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)

View File

@@ -305,10 +305,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)