Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 19:21:31 +06:00
chore: fix typos in language models (#36586)
* chore: fix typos in language models
* chore: fix typos in mistral model
* chore: fix model copy from issue
* chore: fix model copy from issue
* chore: fix model copy from issue
* chore: fix model copy from issue
* chore: fix model copy from issue
This commit is contained in:
parent a929c466d0
commit af9b2eaa54
@@ -1094,7 +1094,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
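As an illustrative aside (not part of the diff): the docstring above lists the arguments models use when materializing the causal mask during cached decoding. A minimal sketch of how a 4D additive mask could be built from them, with a hypothetical helper name and a simplified masking rule rather than the actual transformers code:

import torch

def build_4d_causal_mask(sequence_length, target_length, dtype, device, cache_position, batch_size):
    # Hypothetical helper, sketch only: masked positions hold the most negative
    # value representable in `dtype`, visible positions hold 0.
    min_value = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_value, dtype=dtype, device=device)
    # Query i (at absolute position cache_position[i]) may attend to key positions <= cache_position[i].
    visible = torch.arange(target_length, device=device) <= cache_position.reshape(-1, 1)
    mask = mask.masked_fill(visible, 0.0)
    # Broadcast to the 4D shape (batch_size, 1, sequence_length, target_length).
    return mask[None, None, :, :].expand(batch_size, 1, -1, -1)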
@@ -1399,7 +1399,7 @@ class BambaModel(BambaPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -1140,7 +1140,7 @@ class BambaModel(BambaPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -201,7 +201,7 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
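Illustrative aside, not part of the diff: the comment being corrected describes how the causal mask is aligned when the query length differs from the key length during cached decoding. The two alignments can be visualized with plain masks (variable names invented for the example):

import torch

q_len, k_len = 3, 5  # e.g. 3 new query tokens attending over 2 cached + 3 new key tokens

# Bottom-right alignment (flash_attn>=2.1 default): the last query row sees every key,
# earlier rows see one fewer key each.
bottom_right = torch.ones(q_len, k_len, dtype=torch.long).tril(diagonal=k_len - q_len)

# Top-left alignment (flash_attn<2.1): row i only sees keys 0..i, ignoring the cached
# prefix, which is the difference the `_flash_attn_uses_top_left_mask` attribute tracks.
top_left = torch.ones(q_len, k_len, dtype=torch.long).tril(diagonal=0)

print(bottom_right)
print(top_left)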
@@ -298,7 +298,7 @@ class BartFlashAttention2(BartAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
@@ -274,7 +274,7 @@ class FlaxBartAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.
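For context (an aside, not part of the diff): the docstring above refers to the Flax cached-decoding pattern in which per-step key/value projections are written into preallocated cache buffers. A simplified, self-contained sketch of that idea, without the attention-mask bookkeeping the real method performs:

import jax.numpy as jnp
from jax import lax

def concatenate_to_cache(cached_key, cached_value, cache_index, key, value):
    # Sketch only: cached_key/cached_value are preallocated (batch, max_length, heads, head_dim)
    # buffers; key/value hold the current step's states (batch, cur_len, heads, head_dim).
    start = (0, cache_index, 0, 0)
    new_key = lax.dynamic_update_slice(cached_key, key, start)
    new_value = lax.dynamic_update_slice(cached_value, value, start)
    return new_key, new_value, cache_index + key.shape[1]

# One decoding step writing into an 8-token cache:
cached_k = jnp.zeros((1, 8, 2, 4))
cached_v = jnp.zeros((1, 8, 2, 4))
step_k = jnp.ones((1, 1, 2, 4))
step_v = jnp.ones((1, 1, 2, 4))
cached_k, cached_v, next_index = concatenate_to_cache(cached_k, cached_v, 0, step_k, step_v)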
@@ -263,7 +263,7 @@ class FlaxBertSelfAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -284,7 +284,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -262,7 +262,7 @@ class FlaxBlenderbotAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -273,7 +273,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -824,7 +824,7 @@ class BloomModel(BloomPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -187,7 +187,7 @@ class FlaxBloomAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -376,7 +376,7 @@ class ChameleonFlashAttention2(ChameleonAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1470,7 +1470,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -412,7 +412,7 @@ class CLIPFlashAttention2(CLIPAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
@@ -82,7 +82,7 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
             [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
             contains everything needed to load the tokenizer.
         clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
-            Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
+            Whether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
             spaces.
         unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
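As a hedged usage note on the flag documented above (not part of the diff; the checkpoint name is only a placeholder for any fast tokenizer exposing the flag):

from transformers import AutoTokenizer

# Placeholder checkpoint; any tokenizer exposing the flag behaves the same way.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", clean_up_tokenization_spaces=False)
ids = tok("Hello , world !", add_special_tokens=False).input_ids
print(tok.decode(ids))                                     # spacing artifacts are kept
print(tok.decode(ids, clean_up_tokenization_spaces=True))  # " ," and " !" are collapsed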
@@ -667,7 +667,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -744,7 +744,7 @@ class CohereModel(CoherePreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -406,7 +406,7 @@ class CohereTokenizerFast(PreTrainedTokenizerFast):
             conversation (Union[List[Dict[str, str]]]): A list of dicts
                 with "role" and "content" keys, representing the chat history so far.
             documents (List[Dict[str, str]): A list of dicts, representing documents or tool outputs to ground your
-                generation on. A document is a semistructured dict, wiht a string to string mapping. Common fields are
+                generation on. A document is a semistructured dict, with a string to string mapping. Common fields are
                 `url`, `title`, `snippet` etc but should be descriptive of the key. They will get rendered into the prompt.
             citation_mode: either "accurate" (prompt the model to generate an answer first, then rewrite it with citation
                 spans in) or "fast", where the prompt instructs the model to generate an answer with citations in directly.
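Purely as an illustration of the shape the `documents` argument described above expects (field names here are examples only, since the docstring says any descriptive keys work; not part of the diff):

conversation = [{"role": "user", "content": "Which penguin species is the tallest?"}]
documents = [
    {"title": "Tall penguins", "snippet": "Emperor penguins are the tallest penguin species.", "url": "https://example.com/penguins"},
    {"title": "Penguin habitats", "snippet": "Emperor penguins live only in Antarctica."},
]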
@@ -745,7 +745,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -493,7 +493,7 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -322,7 +322,7 @@ class DbrxFlashAttention2(DbrxAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1199,7 +1199,7 @@ class DbrxModel(DbrxPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -244,7 +244,7 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -983,7 +983,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -174,7 +174,7 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -249,7 +249,7 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -230,7 +230,7 @@ class FlaxElectraSelfAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.
@@ -1562,7 +1562,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -470,7 +470,7 @@ class FalconFlashAttention2(FalconAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1126,7 +1126,7 @@ class FalconModel(FalconPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -234,7 +234,7 @@ class FlaxGemmaAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -716,7 +716,7 @@ class GemmaModel(GemmaPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -757,7 +757,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -725,7 +725,7 @@ class GlmModel(GlmPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -162,7 +162,7 @@ class FlaxGPT2Attention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -282,7 +282,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -148,7 +148,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -278,7 +278,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -876,7 +876,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -719,7 +719,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -746,7 +746,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -174,7 +174,7 @@ class FlaxGPTJAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -270,7 +270,7 @@ class GPTJFlashAttention2(GPTJAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -975,7 +975,7 @@ class GPTJModel(GPTJPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -728,7 +728,7 @@ class GraniteModel(GranitePreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -525,7 +525,7 @@ class GraniteMoeFlashAttention2(GraniteMoeAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1202,7 +1202,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -402,7 +402,7 @@ class GraniteMoeSharedFlashAttention2(GraniteMoeSharedAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1146,7 +1146,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -712,7 +712,7 @@ class HeliumModel(HeliumPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -567,7 +567,7 @@ class HubertFlashAttention2(HubertAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1447,7 +1447,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -277,7 +277,7 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -872,7 +872,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -277,7 +277,7 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
@@ -390,7 +390,7 @@ class JambaFlashAttention2(JambaAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -676,7 +676,7 @@ class JetMoeFlashAttention2(JetMoeAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1208,7 +1208,7 @@ class JetMoeModel(JetMoePreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -228,7 +228,7 @@ class FlaxLlamaAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -714,7 +714,7 @@ class LlamaModel(LlamaPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -431,7 +431,7 @@ class FlaxLongT5Attention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -1684,7 +1684,7 @@ class LongT5Stack(LongT5PreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -352,7 +352,7 @@ class M2M100FlashAttention2(M2M100Attention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -285,7 +285,7 @@ class FlaxMarianAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -286,7 +286,7 @@ class FlaxMBartAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -295,7 +295,7 @@ class MBartFlashAttention2(MBartAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
@@ -602,7 +602,7 @@ class MimiFlashAttention2(MimiAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1170,7 +1170,7 @@ class MimiTransformerModel(nn.Module):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -112,7 +112,7 @@ def get_concat_dim(key):


def convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig):
-    """Convert the state dict, when a single `nn.Module` is sharded accross different files."""
+    """Convert the state dict, when a single `nn.Module` is sharded across different files."""
    new_dict = {}

    num_shards = len(loaded_shards)
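For context (an aside, not part of the diff): `convert_state_dict_sharded` deals with weights of a single `nn.Module` that were split across several checkpoint files, which have to be stitched back together along the sharded dimension. A rough sketch of that idea with a hypothetical helper (the real conversion script also renames keys and handles unsharded tensors):

import torch

def merge_sharded_tensors(loaded_shards: list[dict], key: str, concat_dim: int) -> torch.Tensor:
    # Hypothetical helper, sketch only: gather the same parameter from every shard and
    # concatenate the pieces along the dimension the weight was sharded on.
    pieces = [shard[key] for shard in loaded_shards]
    return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=concat_dim)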
@@ -254,7 +254,7 @@ class FlaxMistralAttention(nn.Module):
     def _concatenate_to_cache(self, key, value, query, attention_mask):
         """
         This function takes projected key, value states from a single input token and concatenates the states to cached
-        states from previous steps. This function is slighly adapted from the official Flax repository:
+        states from previous steps. This function is slightly adapted from the official Flax repository:
         https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
         """
         # detect if we're initializing by absence of existing cache data.

@@ -703,7 +703,7 @@ class MistralModel(MistralPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -221,7 +221,7 @@ class MistralModel(LlamaModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -837,7 +837,7 @@ class MixtralModel(MixtralPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -1162,7 +1162,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@@ -1078,7 +1078,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -571,7 +571,7 @@ class MoshiFlashAttention2(MoshiAttention):
         super().__init__(*args, **kwargs)

         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@@ -1400,7 +1400,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -1714,7 +1714,7 @@ class MoshiModel(MoshiPreTrainedModel):
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
             device (`torch.device`):
-                The device to plcae the 4D attention mask on.
+                The device to place the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):
@ -1276,7 +1276,7 @@ class MT5Stack(MT5PreTrainedModel):
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
The device to place the 4D attention mask on.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
|
@ -328,7 +328,7 @@ class MusicgenFlashAttention2(MusicgenAttention):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
|
@ -344,7 +344,7 @@ class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
|

@ -315,7 +315,7 @@ class NemotronFlashAttention2(NemotronAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -963,7 +963,7 @@ class NemotronModel(NemotronPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -690,7 +690,7 @@ class OlmoModel(OlmoPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -691,7 +691,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -402,7 +402,7 @@ class OlmoeFlashAttention2(OlmoeAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -1122,7 +1122,7 @@ class OlmoeModel(OlmoePreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -150,7 +150,7 @@ class FlaxOPTAttention(nn.Module):
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.

@ -199,7 +199,7 @@ class OptFlashAttention2(OPTAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -723,7 +723,7 @@ class OPTDecoder(OPTPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -77,7 +77,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):

@ -278,7 +278,7 @@ class FlaxPegasusAttention(nn.Module):
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.

@ -763,7 +763,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -688,7 +688,7 @@ class PhiModel(PhiPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -778,7 +778,7 @@ class Phi3Model(Phi3PreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -88,7 +88,7 @@ class PhimoeConfig(PretrainedConfig):
num_local_experts (`int`, *optional*, defaults to 16):
Number of experts per Sparse MLP layer.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.

@ -1284,7 +1284,7 @@ class PhimoeModel(PhimoePreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -1671,7 +1671,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -1084,7 +1084,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -716,7 +716,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -243,7 +243,7 @@ class Qwen2_5_VLConfig(PretrainedConfig):

# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:

@ -604,7 +604,7 @@ class Qwen2_5_VLRotaryEmbedding(nn.Module):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)

@ -646,7 +646,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Explanation:
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no

@ -815,7 +815,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -1349,7 +1349,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -1564,7 +1564,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
width position_ids: [0, 1, 2, 3, 4]

For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embeddin for text part.
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.

@ -432,7 +432,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
width position_ids: [0, 1, 2, 3, 4]

For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embeddin for text part.
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.

@ -227,7 +227,7 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -125,7 +125,7 @@ class Qwen2MoeConfig(PretrainedConfig):
norm_topk_prob (`bool`, *optional*, defaults to `False`):
Whether to normalize the topk probabilities.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabeling this will also
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.

@ -410,7 +410,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -1172,7 +1172,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -232,7 +232,7 @@ class Qwen2VLConfig(PretrainedConfig):

# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:

@ -110,7 +110,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
The max pixels of the image to resize the image.
patch_size (`int`, *optional*, defaults to 14):
The spacial patch size of the vision encoder.
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to 2):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to 2):

@ -86,7 +86,7 @@ class Qwen2VLFastImageProcessorInitKwargs(DefaultFastImageProcessorInitKwargs):
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
The max pixels of the image to resize the image.
patch_size (`int`, *optional*, defaults to 14):
The spacial patch size of the vision encoder.
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to 2):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to 2):

@ -140,7 +140,7 @@ class Qwen2VLRotaryEmbedding(nn.Module):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)

# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)

@ -174,7 +174,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Explanation:
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no

@ -628,7 +628,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -1296,7 +1296,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):

@ -1461,7 +1461,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]

@ -1469,7 +1469,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
width position_ids: [0, 1, 2, 3, 4]

For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embeddin for text part.
and 1D rotary position embedding for text part.
Examples:
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.

@ -224,7 +224,7 @@ class FlaxRobertaSelfAttention(nn.Module):
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.

@ -227,7 +227,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slighly adapted from the official Flax repository:
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.

@ -40,7 +40,7 @@ class JiebaPreTokenizer:
def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
splits = []

# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
# this code slice normalized_string is too slow (6s) but test_alignment_methods can pass
for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
if token in self.vocab:
splits.append(normalized_string[start:end])

@ -52,7 +52,7 @@ class JiebaPreTokenizer:
splits.append(normalized_string[start:end])
start = end

# this code test_alignement_methods can't pass but fast (300ms)
# this code test_alignment_methods can't pass but fast (300ms)
# for token in self.jieba.cut(str(normalized_string), False):
# if token in self.vocab:
# splits.append(NormalizedString(token))

@ -567,7 +567,7 @@ class SEWFlashAttention2(SEWAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

@ -449,7 +449,7 @@ class SiglipFlashAttention2(SiglipAttention):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

Some files were not shown because too many files have changed in this diff.