Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1 (#37332)
* Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1
* fix code format
* add test; replace position_ids with query_states because position_ids.shape[0] is always 1
* add assert loss is not nan
Commit 3ee72af6b6 (parent ae32f1ad11)
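As I read the change (my gloss, not the commit's own wording): the old branch guard ends in `(torch.diff(position_ids, dim=-1) >= 0).all()`, a tensor whose value is only known at runtime, and turning that tensor into a Python bool is data-dependent control flow that torch.compile cannot keep inside a single graph. The new guard puts the static check `query_states.shape[0] == 1` first, so for batch sizes larger than 1 the Python `and` short-circuits at trace time and the tensor-valued check is never reached. A minimal standalone sketch of the same pattern (toy functions, not the transformers code path):

import torch

def data_dependent(x):
    # `.all()` returns a tensor; using it as a Python bool is data-dependent
    # control flow, which TorchDynamo cannot keep inside one graph.
    if (torch.diff(x, dim=-1) >= 0).all():
        return x + 1
    return x - 1

def shape_guarded(x):
    # x.shape[0] is known while tracing, so for a batch of 2 the Python `and`
    # short-circuits here and the tensor-valued check is never traced.
    if x.shape[0] == 1 and (torch.diff(x, dim=-1) >= 0).all():
        return x + 1
    return x - 1

x = torch.arange(10).reshape(2, 5)
compiled = torch.compile(shape_guarded, fullgraph=True)
print(compiled(x).shape)  # torch.Size([2, 5]), compiles without a graph break
# torch.compile(data_dependent, fullgraph=True)(x) fails instead: the tensor-bool
# branch would need a graph break, which fullgraph=True forbids.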
@@ -385,8 +385,10 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and (
-        max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    elif (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
     ):
         batch_size = query_states.size(0)
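The commit message's point about using query_states rather than position_ids for the batch check matches the diff above: position_ids are typically built once and broadcast across the batch, so their leading dimension stays 1, while query_states carry the real batch size. A toy illustration with assumed shapes (not the library code):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 10, 4, 8

# Broadcast-style position ids: one row shared by every example in the batch.
position_ids = torch.arange(seq_len).unsqueeze(0)                     # (1, 10)

# Query states as flash attention sees them: batch dimension first.
query_states = torch.randn(batch_size, seq_len, num_heads, head_dim)  # (2, 10, 4, 8)

print(position_ids.shape[0])   # 1 -- cannot tell batch size 2 apart from 1
print(query_states.shape[0])   # 2 -- the actual batch size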
@@ -4082,6 +4082,45 @@ class ModelTesterMixin:
         # with attention mask
         _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_train_cls"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
+
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        if not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        torch.compiler.reset()
+        torch_dtype = torch.float16
+
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config._attn_implementation = "flash_attention_2"
+        cls = self._torch_compile_train_cls
+        model = cls(config).to(device=torch_device, dtype=torch_dtype)
+
+        inputs = {
+            "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+            "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+        }
+
+        model = torch.compile(model, fullgraph=True)
+        # forward compilation
+        set_seed(42)
+        loss = model(**inputs).loss
+        # backward compilation
+        loss.backward()
+
+        assert not loss.isnan().any()
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
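A note on running the new test (my suggestion, not part of the commit): it is gated by @slow, @require_flash_attn and @require_torch_gpu, so it only executes on a CUDA machine with flash-attn installed and slow tests enabled, e.g. something along the lines of

RUN_SLOW=yes pytest -k test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break tests/models/llama

where the model directory is only an example; any model whose tester defines `_torch_compile_train_cls` and supports `flash_attention_2` exercises it.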