Ignore non-causal mask in more cases with SDPA (#30138)

* update non-causal mask for sdpa

* add test

* update docstrings

* add one more test

* fix cross attention bug

* gentler atol/rtol
This commit is contained in:
fxmarty 2024-06-03 04:08:41 -07:00 committed by GitHub
parent f4f696255f
commit 221aaec6ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 42 additions and 18 deletions

View File

@ -413,7 +413,7 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
@ -429,36 +429,25 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
batch_size, key_value_length = mask.shape
_, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(mask, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
if not is_tracing and torch.all(mask == 1):
if tgt_len == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
return None
elif key_value_length == tgt_len:
return None
else:
# Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
# may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
# Reference: https://github.com/pytorch/pytorch/issues/108108
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
return None
else:
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)

View File

@ -432,7 +432,9 @@ class BertSdpaSelfAttention(BertSelfAttention):
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,

View File

@ -16,7 +16,7 @@ import os
import tempfile
import unittest
from transformers import BertConfig, is_torch_available
from transformers import AutoTokenizer, BertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
CaptureLogger,
@ -747,3 +747,36 @@ class BertModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
def test_sdpa_ignored_mask(self):
pkv = []
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="eager")
model_sdpa = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="sdpa")
model = model.eval()
model_sdpa = model_sdpa.eval()
for _ in range(model.config.num_hidden_layers):
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads
pkv.append([torch.rand(1, num_heads, 3, head_dim), torch.rand(1, num_heads, 3, head_dim)])
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
inp = tokenizer("I am in Paris and", return_tensors="pt")
del inp["attention_mask"]
with torch.no_grad():
res_eager = model(**inp)
res_sdpa = model_sdpa(**inp)
self.assertTrue(
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)
# Case where query length != kv_length.
res_eager = model(**inp, past_key_values=pkv)
res_sdpa = model_sdpa(**inp, past_key_values=pkv)
self.assertTrue(
torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
)