Fix is_causal fail with compile (#36374)

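Fix: when computing `is_causal`, evaluate the `query.shape[2] > 1` check before `causal_mask is None`, so that torch.compile no longer fails with `argument 'is_causal' must be bool, not SymBool`.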
This commit is contained in:
Cyril Vallez 2025-02-25 10:44:56 +01:00 committed by GitHub
parent bc65f3fc1c
commit 401543a825
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -42,8 +42,9 @@ def sdpa_attention_forward(
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# Note that it is important to check the shape first; otherwise, compile will fail with `argument 'is_causal' must be bool, not SymBool`
if is_causal is None:
is_causal = causal_mask is None and query.shape[2] > 1
is_causal = query.shape[2] > 1 and causal_mask is None
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.