mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
parent bc65f3fc1c
commit 401543a825
@@ -42,8 +42,9 @@ def sdpa_attention_forward(
     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
     if is_causal is None:
-        is_causal = causal_mask is None and query.shape[2] > 1
+        is_causal = query.shape[2] > 1 and causal_mask is None
 
     # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
     # We convert it to a bool for the SDPA kernel that only accepts bools.
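
For context, below is a minimal, self-contained sketch of the pattern this hunk touches. The simplified signature and the final scaled_dot_product_attention call are assumptions made for illustration and do not reproduce the full transformers function; only the `is_causal` dispatch and the tracing-time bool conversion mirror the lines above.

import torch
import torch.nn.functional as F

def sdpa_attention_forward(query, key, value, causal_mask=None, is_causal=None):
    # Sketch only: simplified signature, not the actual transformers API.
    # Dispatch to SDPA's Flash Attention / Efficient kernels through an explicit
    # `is_causal` if statement rather than an inline conditional assignment, so
    # both torch.compile's dynamic shapes and fullgraph options keep working.
    if is_causal is None:
        # The shape check comes first: under torch.compile the comparison yields
        # a SymBool, and evaluating it before `causal_mask is None` avoids the
        # "argument 'is_causal' must be bool, not SymBool" failure.
        is_causal = query.shape[2] > 1 and causal_mask is None

    # Shapes are tensors during jit tracing, so `is_causal` may be a tensor;
    # the SDPA kernel only accepts a plain bool.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, is_causal=is_causal
    )

if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    print(sdpa_attention_forward(q, k, v).shape)  # torch.Size([1, 8, 16, 64])

Because `is_causal` is only True when no mask is supplied, the mask and the causal flag are never passed to the kernel at the same time, which is the point of this dispatch.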