[Falcon] Remove SDPA for falcon to support earlier versions of PyTorch (< 2.0) (#25947)

* remove SDPA for falcon

* revert previous behaviour and add warning

* nit

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/models/falcon/modeling_falcon.py

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Younes Belkada 2023-09-04 20:34:04 +02:00 committed by GitHub
parent 22a69f1d7d
commit 49b69fe0d4
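For context, the change gates the fast path on feature detection rather than a version string: torch.nn.functional only gained scaled_dot_product_attention in PyTorch 2.0, so hasattr doubles as a version check on earlier releases. A minimal standalone sketch of the pattern the diff relies on (the attention helper below is illustrative, not the model's actual code):

    import math
    import torch
    import torch.nn.functional as F

    # PyTorch >= 2.0 exposes the fused kernel; earlier releases simply
    # lack the attribute, so hasattr works as a feature check.
    SDPA_AVAILABLE = hasattr(F, "scaled_dot_product_attention")

    def attention(query, key, value, mask):
        # query/key/value: (batch, heads, seq_len, head_dim); mask: additive float mask
        if SDPA_AVAILABLE:
            # Fused path; note it does not return the attention weights.
            return F.scaled_dot_product_attention(query, key, value, mask, 0.0, is_causal=False)
        # Manual matmul + softmax path that also runs on PyTorch < 2.0,
        # mirroring the else branch in the diff below.
        scores = query @ key.transpose(-1, -2) / math.sqrt(query.size(-1))
        weights = F.softmax(scores + mask, dim=-1)
        return weights @ value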


@@ -422,9 +422,19 @@ class FalconAttention(nn.Module):
         value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
 
         if alibi is None:
-            if output_attentions:
-                # F.scaled_dot_product_attention doesn't return the attention weights, so we have
-                # to do it by hand if we want them
+            if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
+                # TODO: deprecate this once we add FA2 support in Falcon
+                logger.warning_once(
+                    "The current implementation of Falcon calls `torch.scaled_dot_product_attention` directly, this will be deprecated in the"
+                    " future in favor of the `BetterTransformer` API. Please install the latest optimum library with `pip install -U optimum` and call "
+                    "`model.to_bettertransformer()` to benefit from `torch.scaled_dot_product_attention` and future performance optimizations."
+                )
+
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
+                )
+                attention_scores = None
+            else:
                 attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
                 attention_scores /= math.sqrt(self.head_dim)
 
@@ -432,11 +442,6 @@ class FalconAttention(nn.Module):
                     attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
                 )
                 attn_output = attention_scores @ value_layer_
-            else:
-                attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
-                )
-                attention_scores = None
 
         attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
         attn_output = attn_output.permute(0, 2, 1, 3)
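
The migration path the new warning recommends would look roughly like the sketch below (tiiuae/falcon-7b is just an example checkpoint; to_bettertransformer() requires the optimum package, as the warning states):

    # pip install -U optimum
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b")
    # Hands attention over to the BetterTransformer path, which uses
    # torch.scaled_dot_product_attention under the hood.
    model = model.to_bettertransformer()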