mirror of https://github.com/huggingface/transformers.git
parent 98dd842339
commit 924c46d40c
@@ -310,7 +310,7 @@ class CohereAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 Llama->Cohere
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
 class CohereFlashAttention2(CohereAttention):
     """
     Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
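
For context, `# Copied from <source> with Llama->Cohere` is the convention transformers uses to mark code duplicated from another model so automated tooling can keep the copies in sync; the `# Ignore copy` line added in the second hunk below exempts the following method from that check. A minimal sketch of how such a check could work, assuming a hypothetical helper name (`check_copied_from` is illustrative, not the repository's actual utils/check_copies.py):

# Minimal sketch of a "Copied from ... with Old->New" consistency check.
# check_copied_from is a hypothetical helper, not the repo's real tooling.
import inspect
import re


def check_copied_from(copy_obj, source_obj, marker: str) -> bool:
    """Return True if copy_obj's source matches source_obj's source after
    applying every Old->New substitution named in the marker comment."""
    source_code = inspect.getsource(source_obj)
    copy_code = inspect.getsource(copy_obj)
    # Parse each "Old->New" pair out of the marker, e.g. "Llama->Cohere".
    for old, new in re.findall(r"(\w+)->(\w+)", marker):
        source_code = source_code.replace(old, new)
    return source_code == copy_code

For the hunk above, the marker would rename Llama to Cohere in the Llama source before comparing it against the Cohere copy, which is why the missing "with" mattered: without it the marker does not parse as a substitution rule.
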
@@ -326,6 +326,7 @@ class CohereFlashAttention2(CohereAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
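
The `_flash_attn_uses_top_left_mask` flag gates behavior on the installed flash-attn version: releases before 2.1 build the causal mask from the top-left corner when q_seqlen != k_seqlen, so the module records whether that workaround is needed. A minimal sketch of an equivalent version gate, assuming only importlib.metadata and packaging (transformers' own helper is `is_flash_attn_greater_or_equal_2_10`; the function below and the distribution-name lookup are assumptions for illustration):

# Minimal sketch of a flash-attn version gate; flash_attn_uses_top_left_mask
# is an illustrative name, and the distribution is assumed to resolve as
# "flash_attn" via importlib.metadata.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def flash_attn_uses_top_left_mask() -> bool:
    """True when the installed flash-attn is older than 2.1 and therefore
    aligns the causal mask to the top-left for q_seqlen != k_seqlen."""
    try:
        installed = parse(version("flash_attn"))
    except PackageNotFoundError:
        # flash-attn not installed; no workaround to apply.
        return False
    return installed < parse("2.1.0")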