Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[VLMs] fix flash-attention tests (#37603)
* fix one test
* fa2 ln test
* remove keys from config recursively
* fix
* fixup

parent 02baa61fab
commit 1cfcbfcab8
@@ -843,29 +843,16 @@ class PretrainedConfig(PushToHubMixin):
            ):
                serializable_config_dict[key] = value

        self._remove_keys_not_serialized(serializable_config_dict)

        if hasattr(self, "quantization_config"):
            serializable_config_dict["quantization_config"] = (
                self.quantization_config.to_dict()
                if not isinstance(self.quantization_config, dict)
                else self.quantization_config
            )
            # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
            _ = serializable_config_dict.pop("_pre_quantization_dtype", None)

        self.dict_torch_dtype_to_str(serializable_config_dict)

        if "_attn_implementation_internal" in serializable_config_dict:
            del serializable_config_dict["_attn_implementation_internal"]
        # Do not serialize `base_model_tp_plan` for now
        if "base_model_tp_plan" in serializable_config_dict:
            del serializable_config_dict["base_model_tp_plan"]
        # Do not serialize `base_model_pp_plan` for now
        if "base_model_pp_plan" in serializable_config_dict:
            del serializable_config_dict["base_model_pp_plan"]

        if "_name_or_path" in serializable_config_dict:
            del serializable_config_dict["_name_or_path"]

        return serializable_config_dict

    def to_dict(self) -> dict[str, Any]:
@@ -878,18 +865,6 @@ class PretrainedConfig(PushToHubMixin):
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        if "_auto_class" in output:
            del output["_auto_class"]
        if "_commit_hash" in output:
            del output["_commit_hash"]
        if "_attn_implementation_internal" in output:
            del output["_attn_implementation_internal"]
        # Do not serialize `base_model_tp_plan` for now
        if "base_model_tp_plan" in output:
            del output["base_model_tp_plan"]
        # Do not serialize `base_model_pp_plan` for now
        if "base_model_pp_plan" in output:
            del output["base_model_pp_plan"]

        # Transformers version when serializing the model
        output["transformers_version"] = __version__
@@ -902,16 +877,14 @@ class PretrainedConfig(PushToHubMixin):

                output[key] = value

        self._remove_keys_not_serialized(output)

        if hasattr(self, "quantization_config"):
            output["quantization_config"] = (
                self.quantization_config.to_dict()
                if not isinstance(self.quantization_config, dict)
                else self.quantization_config
            )

            # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
            _ = output.pop("_pre_quantization_dtype", None)

        self.dict_torch_dtype_to_str(output)

        return output
@@ -1011,6 +984,33 @@ class PretrainedConfig(PushToHubMixin):
            if isinstance(value, dict):
                self.dict_torch_dtype_to_str(value)

    def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
        """
        Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
        Runs recursive check on the dict, to remove from all sub configs.
        """
        if hasattr(self, "quantization_config"):
            # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
            _ = d.pop("_pre_quantization_dtype", None)

        if "_auto_class" in d:
            del d["_auto_class"]
        if "_commit_hash" in d:
            del d["_commit_hash"]
        if "_attn_implementation_internal" in d:
            del d["_attn_implementation_internal"]
        # Do not serialize `base_model_tp_plan` for now
        if "base_model_tp_plan" in d:
            del d["base_model_tp_plan"]
        # Do not serialize `base_model_pp_plan` for now
        if "base_model_pp_plan" in d:
            del d["base_model_pp_plan"]
        if "_name_or_path" in d:
            del d["_name_or_path"]
        for value in d.values():
            if isinstance(value, dict):
                self._remove_keys_not_serialized(value)

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoConfig"):
        """
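For illustration only, a minimal standalone sketch of the recursive cleanup idea introduced above. This is not the library method itself; the key list is abbreviated and the toy config dict is invented.

NON_SERIALIZABLE_KEYS = ("_auto_class", "_commit_hash", "_attn_implementation_internal", "_name_or_path")

def remove_keys_not_serialized(d: dict) -> None:
    # Drop internal keys at this level, then recurse into nested sub-config dicts.
    for key in NON_SERIALIZABLE_KEYS:
        d.pop(key, None)
    for value in d.values():
        if isinstance(value, dict):
            remove_keys_not_serialized(value)

config_dict = {
    "model_type": "toy-vlm",  # invented example values
    "_attn_implementation_internal": "flash_attention_2",
    "text_config": {"hidden_size": 64, "_name_or_path": "local/checkpoint"},
}
remove_keys_not_serialized(config_dict)
print(config_dict)  # internal keys are gone from both the top level and the nested dict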
@@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            # once the weights have been quantized
            # Note that once you have loaded a quantized model, you can't change its dtype so this will
            # remain a single source of truth
            config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
            original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()

            def _assign_original_dtype(module):
                for child in module.children():
                    if isinstance(child, PreTrainedModel):
                        child.config._pre_quantization_dtype = original_dtype
                    _assign_original_dtype(child)

            config._pre_quantization_dtype = original_dtype
            _assign_original_dtype(model)

        # Prepare the full device map
        if device_map is not None:
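Again for illustration only, a toy sketch of the recursive dtype propagation above, using invented stand-in classes rather than the real PreTrainedModel/config types:

import torch
import torch.nn as nn

class ToyConfig:
    pass

class ToySubModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = ToyConfig()
        self.proj = nn.Linear(4, 4)

class ToyCompositeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = ToyConfig()
        self.vision_tower = ToySubModel()
        self.language_model = ToySubModel()

def assign_original_dtype(module, original_dtype):
    # Walk the module tree and stamp the dtype on every sub-model that carries a config.
    for child in module.children():
        if hasattr(child, "config"):
            child.config._pre_quantization_dtype = original_dtype
        assign_original_dtype(child, original_dtype)

model = ToyCompositeModel()
model.config._pre_quantization_dtype = torch.float16
assign_original_dtype(model, torch.float16)
print(model.vision_tower.config._pre_quantization_dtype)  # torch.float16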
@@ -125,6 +125,9 @@ class InternVLVisionAttention(nn.Module):
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm

        # Needed for flash attention
        self.is_causal = False

        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
@@ -134,9 +137,6 @@ class InternVLVisionAttention(nn.Module):
        self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
        self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()

        # Needed for flash attention
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -344,6 +344,7 @@ class JanusVisionAttention(nn.Module):
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm
        self.is_causal = False

        # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
        self.num_key_value_groups = 1
@@ -398,7 +399,7 @@ class JanusVisionAttention(nn.Module):
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scale,
            is_causal=False,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
@@ -509,6 +509,7 @@ class JanusVisionAttention(nn.Module):
        self.attention_dropout = config.attention_dropout
        proj_dropout = config.projection_dropout
        qk_norm = config.use_qk_norm
        self.is_causal = False

        # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
        self.num_key_value_groups = 1
@@ -563,7 +564,7 @@ class JanusVisionAttention(nn.Module):
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scale,
            is_causal=False,
            is_causal=self.is_causal,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
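The attention changes above only define `self.is_causal` earlier in `__init__` and route the attribute (rather than a hard-coded literal) to the attention call. As a rough standalone illustration of why that matters for backend dispatch, here is a toy bidirectional vision self-attention module, not the InternVL/Janus code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyVisionSelfAttention(nn.Module):
    def __init__(self, embed_dim: int = 32, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Vision self-attention is bidirectional, so the module-level flag is False;
        # attention backends (SDPA/flash-attention dispatch) can read it from the module.
        self.is_causal = False
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        b, s, d = hidden_states.shape
        q = self.q_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        # Pass the attribute, not a literal, so every dispatch path sees the same value.
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
        attn = attn.transpose(1, 2).reshape(b, s, d)
        return self.out_proj(attn)

x = torch.randn(2, 8, 32)
print(ToyVisionSelfAttention()(x).shape)  # torch.Size([2, 8, 32])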
@@ -316,10 +316,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
    def test_sdpa_can_compile_dynamic(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    # todo: yoni - fix or improve the test
    @unittest.skip("Difference is slightly higher than the threshold")
    def test_batching_equivalence(self):
@@ -222,8 +222,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
    @unittest.skip(
        reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
    )
    def test_past_key_values_format(self):
        pass

@@ -407,10 +407,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
    def test_prompt_lookup_decoding_matches_greedy_search(self):
        pass

    @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @pytest.mark.generate
    @require_torch_sdpa
    @slow
@@ -367,10 +367,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
    def test_prompt_lookup_decoding_matches_greedy_search(self):
        pass

    @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @pytest.mark.generate
    @require_torch_sdpa
    @slow
@@ -302,10 +302,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
@@ -331,10 +331,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
@@ -302,10 +302,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
@@ -334,10 +334,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
@@ -331,10 +331,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
@@ -378,10 +378,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
    def test_generate_methods_with_logits_to_keep(self):
        super().test_generate_methods_with_logits_to_keep()

    @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass
@@ -225,10 +225,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
@@ -305,10 +305,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip("FlashAttention only support fp16 and bf16 data type")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
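The test-file changes above all follow the same pattern: each model test class overrides an inherited common test and skips it, since FlashAttention kernels only run in fp16/bf16. A stripped-down sketch of that override-and-skip pattern, using toy class names rather than the real ModelTesterMixin tests:

import unittest

class CommonAttentionTests:
    # Stand-in for a shared test defined once in a common mixin.
    def test_flash_attn_2_fp32_ln(self):
        raise AssertionError("would try to run FlashAttention with float32 layer norms")

class ToyVLMModelTest(CommonAttentionTests, unittest.TestCase):
    # Overriding the inherited test and decorating it with skip disables it for this model only.
    @unittest.skip("FlashAttention only supports fp16 and bf16 data types")
    def test_flash_attn_2_fp32_ln(self):
        pass

if __name__ == "__main__":
    unittest.main()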