Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-26 07:49:01 +06:00)
Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1 (#37332)
* Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1
* Fix code format
* Add test; replace position_ids with query_states, because position_ids.shape[0] is always 1
* Add an assert that the loss is not NaN
parent ae32f1ad11
commit 3ee72af6b6
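For context, a minimal sketch of the scenario named in the commit title: a FlashAttention-2 model compiled with fullgraph=True and called with a batch of more than one sequence and no attention_mask. The checkpoint id below is a placeholder (an assumption, not taken from the PR), and the snippet presumes a CUDA device with fp16 support; it mirrors the new test added further down rather than reproducing any exact script from the commit.

import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint (assumption): any causal LM that supports flash_attention_2 works here.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")

# fullgraph=True makes torch.compile raise on any graph break instead of silently splitting the graph.
compiled = torch.compile(model, fullgraph=True)

# Batch size 2, attention_mask deliberately left as None: the case this commit fixes.
input_ids = torch.randint(1, model.config.vocab_size, (2, 10), device="cuda")
loss = compiled(input_ids=input_ids, labels=input_ids).loss
loss.backward()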
@@ -385,8 +385,10 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and (
-        max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    elif (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
     ):
         batch_size = query_states.size(0)
 
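To make the hunk above easier to read in isolation, here is a standalone restatement of the updated elif condition as a plain function. The variable names mirror those in _flash_attention_forward, but the helper itself is illustrative only, not library code. The added query_states.shape[0] == 1 guard is a static shape check, so for batches larger than one the data-dependent torch.diff(position_ids, ...).all() comparison is short-circuited away, which is presumably what removes the graph break under torch.compile.

import torch

def takes_varlen_path(position_ids, query_states, query_length, max_length_q=None):
    # Illustrative restatement of the updated condition (not library code).
    return (
        position_ids is not None
        and query_states.shape[0] == 1  # new static guard: the real batch size is read from query_states
        and (
            max_length_q is not None
            or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
        )
    )

# Batch of 2 query states with broadcast position_ids (batch dim 1): the varlen branch is skipped.
query_states = torch.randn(2, 10, 8, 64)        # (batch, seq_len, num_heads, head_dim)
position_ids = torch.arange(10).unsqueeze(0)    # shape (1, 10)
print(takes_varlen_path(position_ids, query_states, query_length=10))  # False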
@@ -4082,6 +4082,45 @@ class ModelTesterMixin:
         # with attention mask
         _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_can_compile_with_attention_mask_None_without_graph_break(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_train_cls"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
+
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        if not is_torch_fp16_available_on_device(torch_device):
+            self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
+
+        torch.compiler.reset()
+        torch_dtype = torch.float16
+
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config._attn_implementation = "flash_attention_2"
+        cls = self._torch_compile_train_cls
+        model = cls(config).to(device=torch_device, dtype=torch_dtype)
+
+        inputs = {
+            "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+            "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+        }
+
+        model = torch.compile(model, fullgraph=True)
+        # forward compilation
+        set_seed(42)
+        loss = model(**inputs).loss
+        # backward compilation
+        loss.backward()
+
+        assert not loss.isnan().any()
+
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test