mirror of https://github.com/huggingface/transformers.git
synced 2025-07-18 03:58:25 +06:00
parent 98dd842339
commit 924c46d40c
@@ -310,7 +310,7 @@ class CohereAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 Llama->Cohere
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
 class CohereFlashAttention2(CohereAttention):
     """
     Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
@@ -326,6 +326,7 @@ class CohereFlashAttention2(CohereAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,