mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Finish adding support for torch.compile dynamic shapes (#30919)
add torch.compile dynamic support
This commit is contained in:
parent
6739e1d261
commit
046c2ad792
@ -615,13 +615,17 @@ class DbrxSdpaAttention(DbrxAttention):
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attn_pdrop if self.training else 0.0,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
@ -725,14 +725,18 @@ class JambaSdpaAttention(JambaAttention):
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal=self.is_causal and attention_mask is None and q_len > 1,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
@ -624,15 +624,17 @@ class JetMoeSdpaAttention(JetMoeAttention):
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
|
||||
# relying on the `is_causal` argument.
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
Loading…
Reference in New Issue
Block a user