From 3ee72af6b6133be5280a1abcf2cb7b497555f537 Mon Sep 17 00:00:00 2001
From: efsotr <104755879+efsotr@users.noreply.github.com>
Date: Wed, 25 Jun 2025 15:58:34 +0800
Subject: [PATCH] Fix graph break in torch.compile when using FA2 with
 attention_mask=None and batch size > 1 (#37332)

* Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1

* fix code format

* add test; replace position_ids with query_states because position_ids.shape[0] is always 1

* add assert loss is not nan

---
 .../modeling_flash_attention_utils.py |  6 ++-
 tests/test_modeling_common.py         | 39 +++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 03e2922b558..7f3df329432 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -385,8 +385,10 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and (
-        max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    elif (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
     ):
         batch_size = query_states.size(0)
 
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f404f996283..f7183089044 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4082,6 +4082,45 @@ class ModelTesterMixin:
             # with attention mask
             _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_train_cls"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
+
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        if not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        torch.compiler.reset()
+        torch_dtype = torch.float16
+
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config._attn_implementation = "flash_attention_2"
+        cls = self._torch_compile_train_cls
+        model = cls(config).to(device=torch_device, dtype=torch_dtype)
+
+        inputs = {
+            "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+            "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+        }
+
+        model = torch.compile(model, fullgraph=True)
+        # forward compilation
+        set_seed(42)
+        loss = model(**inputs).loss
+        # backward compilation
+        loss.backward()
+
+        assert not loss.isnan().any()
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
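
Note (not part of the patch): a minimal repro sketch of the scenario this change targets, assuming flash-attn 2 and a CUDA GPU are available; the checkpoint name below is a placeholder for any causal LM. Per the commit message, when `attention_mask=None` and the batch size is greater than 1, `position_ids` is built internally with shape `(1, seq_len)` while `query_states` carries the real batch dimension, so the old condition could reach the data-dependent `(torch.diff(position_ids, dim=-1) >= 0).all()` check and break the compiled graph; the added static `query_states.shape[0] == 1` guard short-circuits that branch so `fullgraph=True` compilation can succeed.

# Illustrative sketch only, not part of the patch. Assumptions: flash-attn 2 and a
# CUDA device are available; the model id is a hypothetical placeholder.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "placeholder/causal-lm-checkpoint",  # hypothetical checkpoint id
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

compiled = torch.compile(model, fullgraph=True)  # fullgraph turns graph breaks into errors

# Batch size 2 with no attention_mask: position_ids has shape (1, seq_len) internally,
# while query_states.shape[0] reflects the actual batch size of 2.
input_ids = torch.randint(1, model.config.vocab_size, (2, 10), device="cuda")
loss = compiled(input_ids=input_ids, labels=input_ids).loss  # graph-broke here before the fix
loss.backward()
assert not loss.isnan().any()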