Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
[Falcon] Remove SDPA for falcon to support earlier versions of PyTorch (< 2.0) (#25947)
* remove SDPA for falcon

* revert previous behaviour and add warning

* nit

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/falcon/modeling_falcon.py

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
parent 22a69f1d7d
commit 49b69fe0d4
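The change lands in src/transformers/models/falcon/modeling_falcon.py (diff below). Rather than dropping `F.scaled_dot_product_attention` outright, Falcon keeps it behind a runtime `hasattr` check, since the function only exists in PyTorch >= 2.0, and logs a deprecation warning; the manual matmul-softmax path is used on older PyTorch versions or when attention weights are requested. As a minimal, self-contained sketch of that version-gating pattern (function and tensor names here are illustrative, not the model's own code):

import math
import torch.nn.functional as F

def gated_attention(query, key, value, additive_mask):
    # Illustrative sketch: use the fused SDPA kernel when PyTorch >= 2.0 provides it,
    # otherwise fall back to an explicit matmul-softmax implementation (PyTorch < 2.0).
    if hasattr(F, "scaled_dot_product_attention"):
        # The fused kernel does not return attention weights.
        return F.scaled_dot_product_attention(query, key, value, additive_mask)
    # Manual path: scaled dot product, additive mask, softmax, weighted sum.
    scores = (query @ key.transpose(-1, -2)) / math.sqrt(query.size(-1))
    weights = F.softmax(scores + additive_mask, dim=-1)  # additive_mask holds large negative values at masked positions
    return weights @ value

With this gate in place the same module runs on both PyTorch 1.x (manual path) and 2.x (fused kernel); because the fused kernel cannot return attention weights, the model also keeps the `and not output_attentions` condition and computes attention by hand whenever the caller asks for the weights.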
@@ -422,9 +422,19 @@ class FalconAttention(nn.Module):
         value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
 
         if alibi is None:
-            if output_attentions:
-                # F.scaled_dot_product_attention doesn't return the attention weights, so we have
-                # to do it by hand if we want them
+            if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
+                # TODO: deprecate this once we add FA2 support in Falcon
+                logger.warning_once(
+                    "The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the"
+                    " future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call "
+                    "`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations."
+                )
+
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
+                )
+                attention_scores = None
+            else:
                 attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
                 attention_scores /= math.sqrt(self.head_dim)
 
@@ -432,11 +442,6 @@ class FalconAttention(nn.Module):
                     attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
                 )
                 attn_output = attention_scores @ value_layer_
-            else:
-                attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
-                )
-                attention_scores = None
 
             attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
             attn_output = attn_output.permute(0, 2, 1, 3)
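For inference code, the warning added above points at the BetterTransformer integration instead of relying on the direct SDPA call. A hedged usage sketch, assuming optimum is installed (`pip install -U optimum`); the checkpoint name and prompt are only examples:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # example Falcon checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Requires the optimum library; routes attention through
# torch.nn.functional.scaled_dot_product_attention under the hood.
model = model.to_bettertransformer()

inputs = tokenizer("Falcon models are", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The manual fallback in the diff is what keeps the stock model usable on PyTorch < 2.0, which is the point of this patch.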