From 1dc619e59f4f1103a30a303404a2b0990d45f07c Mon Sep 17 00:00:00 2001
From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Date: Wed, 4 Jun 2025 13:37:28 +0200
Subject: [PATCH] [`FlexAttn`] Fix models with unique characteristics (#38433)

* fix

* style

* check

* check 2

* add deepseek workaround
---
 .../deepseek_v3/test_modeling_deepseek_v3.py | 30 +++++++++++++++++++
 tests/models/zamba2/test_modeling_zamba2.py  | 22 ++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
index c31072bf848..7677b909c3b 100644
--- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
+++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
@@ -24,6 +24,7 @@ from transformers.testing_utils import (
     require_read_token,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
     require_torch_large_accelerator,
     require_torch_sdpa,
     slow,
@@ -494,6 +495,35 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                 msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
             )
 
+    @require_torch_gpu
+    def test_flex_attention_with_grads(self):
+        """
+        Overwriting as the namings/functionality on the attention part are different; for now it's more of a unique model.
+        Original issue is also due to dimensionalities, here specifically due to dims not being a multiple of 2.
+        """
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config._attn_implementation = "flex_attention"
+
+            # Disable dropout
+            config.attention_dropout = 0.0
+
+            # Deepseek 3 specific - manipulate nope and adjust calculated total head dim
+            config.qk_nope_head_dim = 16
+            config.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+
+            model = model_class(config).to(device=torch_device)
+            self.assertTrue(model.config._attn_implementation == "flex_attention")
+
+            # Elaborate workaround for encoder-decoder models as some do not specify their main input
+            dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
+            if config.is_encoder_decoder:
+                dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
+                dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
+
+            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
+            _ = model(**dummy_inputs)
+
 
 @require_torch_accelerator
 class DeepseekV3IntegrationTest(unittest.TestCase):
diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py
index 894cde8be33..4c12ac47323 100644
--- a/tests/models/zamba2/test_modeling_zamba2.py
+++ b/tests/models/zamba2/test_modeling_zamba2.py
@@ -578,6 +578,28 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_new_cache_format(self, num_beams, do_sample):
         pass
 
+    @require_torch_gpu
+    def test_flex_attention_with_grads(self):
+        """
+        Overwriting as the base hidden size is big enough for compile.
+        Manipulation of dims causes issues due to other constraints not being satisfied anymore.
+        """
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config._attn_implementation = "flex_attention"
+
+            model = model_class(config).to(device=torch_device)
+            self.assertTrue(model.config._attn_implementation == "flex_attention")
+
+            # Elaborate workaround for encoder-decoder models as some do not specify their main input
+            dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
+            if config.is_encoder_decoder:
+                dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
+                dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
+
+            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
+            _ = model(**dummy_inputs)
+
 
 @require_torch
 class Zamba2ModelIntegrationTest(unittest.TestCase):