Do not drop mask with SDPA for more cases (#30311)

* overlooked

* style

* cleaner
This commit is contained in:
fxmarty 2024-04-18 14:37:09 +02:00 committed by GitHub
parent acab997bef
commit 63c5e27efb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -319,8 +319,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
ignore_causal_mask = False
if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
if (
not is_tracing
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window:
# 4d mask is passed through
if len(attention_mask.shape) == 4: