Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1 (#37332)

* Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1

* fix code format

* add test; replace position_ids with query_states because position_ids.shape[0] is always 1 (see the sketch below)

* add assert loss is not nan
efsotr 2025-06-25 15:58:34 +08:00 committed by GitHub
parent ae32f1ad11
commit 3ee72af6b6
2 changed files with 43 additions and 2 deletions
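
Note: the change reads the batch size from query_states rather than position_ids because, when no attention_mask is passed, position_ids is typically built once and broadcast across the batch, so its leading dimension is 1 regardless of the actual batch size. A minimal illustration (the tensor shapes and names below are placeholders, not library code):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 10, 8, 64

# With attention_mask=None, position ids are commonly produced by an arange
# that is broadcast over the batch, so the leading dimension stays 1.
position_ids = torch.arange(seq_len).unsqueeze(0)                     # shape (1, 10)
query_states = torch.randn(batch_size, seq_len, num_heads, head_dim)  # shape (2, 10, 8, 64)

print(position_ids.shape[0])   # 1, independent of the real batch size
print(query_states.shape[0])   # 2, the actual batch size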


@@ -385,8 +385,10 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and (
-        max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    elif (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
     ):
         batch_size = query_states.size(0)
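
The added query_states.shape[0] == 1 guard matters under torch.compile because tensor shapes are known at trace time, while the (torch.diff(position_ids, dim=-1) >= 0).all() check is data-dependent: converting that tensor to a Python bool is what forces the graph break. With a batch size greater than 1, the `and` now short-circuits on the static shape check and the data-dependent branch is never evaluated. A minimal sketch of that behavior (illustration only, not the library code; names are placeholders):

import torch

def dispatch(query_states, position_ids):
    # query_states.shape[0] is known to the compiler, so for batch size > 1 the
    # `and` short-circuits here and the data-dependent `.all()` (which would
    # otherwise require a graph break) is never turned into a Python bool.
    if query_states.shape[0] == 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        return query_states * 2   # stands in for the varlen / packed-sequence path
    return query_states + 1       # stands in for the regular path

compiled = torch.compile(dispatch, fullgraph=True)
q = torch.randn(2, 10, 8, 64)        # batch size > 1
pos = torch.arange(10).unsqueeze(0)  # position_ids of shape (1, 10)
out = compiled(q, pos)               # compiles cleanly with fullgraph=True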


@@ -4082,6 +4082,45 @@ class ModelTesterMixin:
                 # with attention mask
                 _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+        if not hasattr(self, "_torch_compile_train_cls"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+        if not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        torch.compiler.reset()
+        torch_dtype = torch.float16
+
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config._attn_implementation = "flash_attention_2"
+
+        cls = self._torch_compile_train_cls
+        model = cls(config).to(device=torch_device, dtype=torch_dtype)
+
+        inputs = {
+            "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+            "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+        }
+
+        model = torch.compile(model, fullgraph=True)
+        # forward compilation
+        set_seed(42)
+        loss = model(**inputs).loss
+        # backward compilation
+        loss.backward()
+
+        assert not loss.isnan().any()
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test