[FlexAttn] Fix models with unique characteristics (#38433)

* fix

* style

* check

* check 2

* add deepseek workaround
Anton Vlasjuk authored on 2025-06-04 13:37:28 +02:00 (committed by GitHub)
parent ff3fad61e3
commit 1dc619e59f
2 changed files with 52 additions and 0 deletions


@@ -24,6 +24,7 @@ from transformers.testing_utils import (
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_large_accelerator,
    require_torch_sdpa,
    slow,
@@ -494,6 +495,35 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
            msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
        )
    @require_torch_gpu
    def test_flex_attention_with_grads(self):
        """
        Overwriting as the naming/functionality of the attention part differs; for now it is more of a unique model.
        The original issue is also due to dimensionality, here specifically the head dims not being a multiple of 2.
        """
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config._attn_implementation = "flex_attention"
            # Disable dropout
            config.attention_dropout = 0.0
            # DeepSeek-V3 specific - manipulate the nope dim and adjust the calculated total head dim
            config.qk_nope_head_dim = 16
            config.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

            model = model_class(config).to(device=torch_device)
            self.assertTrue(model.config._attn_implementation == "flex_attention")

            # Elaborate workaround for encoder-decoder models as some do not specify their main input
            dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
            if config.is_encoder_decoder:
                dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
                dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)

            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
            _ = model(**dummy_inputs)

@require_torch_accelerator
class DeepseekV3IntegrationTest(unittest.TestCase):
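
The override above comes down to the head-dim arithmetic on the DeepSeek-V3 config. A minimal standalone sketch of that workaround (not part of the diff), assuming an installed transformers version that ships `DeepseekV3Config`:

```python
# Sketch of the DeepSeek-V3 workaround in isolation: shrink the non-rotary ("nope")
# part of the query/key head dim and recompute the total so it stays compatible with
# the flex attention shape constraints mentioned in the test docstring.
from transformers import DeepseekV3Config

config = DeepseekV3Config()
config._attn_implementation = "flex_attention"  # same switch the test flips
config.attention_dropout = 0.0                  # dropout disabled, as in the test
config.qk_nope_head_dim = 16                    # shrink the non-rotary part
config.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim  # recompute the total
print(config.qk_head_dim)  # total per-head q/k dim the attention layers will use
```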


@@ -578,6 +578,28 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
    def test_new_cache_format(self, num_beams, do_sample):
        pass
    @require_torch_gpu
    def test_flex_attention_with_grads(self):
        """
        Overwriting as the base hidden size is already big enough for compile.
        Manipulating the dims causes issues because other constraints are then no longer satisfied.
        """
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config._attn_implementation = "flex_attention"
            model = model_class(config).to(device=torch_device)
            self.assertTrue(model.config._attn_implementation == "flex_attention")

            # Elaborate workaround for encoder-decoder models as some do not specify their main input
            dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
            if config.is_encoder_decoder:
                dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
                dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)

            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
            _ = model(**dummy_inputs)

@require_torch
class Zamba2ModelIntegrationTest(unittest.TestCase):
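
For context, a hedged sketch of how the same attention backend is selected outside the test harness; the checkpoint name and dtype below are placeholders chosen for illustration and are not touched by this PR:

```python
# Illustrative only: loading a model with the flex attention backend enabled.
# "Zyphra/Zamba2-1.2B" is an example checkpoint; any model whose attention
# implementation supports "flex_attention" can be loaded the same way.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-1.2B",
    attn_implementation="flex_attention",  # same value the tests set via config._attn_implementation
    torch_dtype=torch.bfloat16,
).to("cuda")  # the new tests are GPU-gated via @require_torch_gpu
```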