Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-15 10:38:23 +06:00)
🧹 Remove deprecated RotaryEmbedding parts in the Attention layers (#34858)
* update
* style
* fix missing args
* remove last trace of old rope classes
* remove deprecated copied from
* fix copies
* trigger CIs
* post rebase clean-up
* reverse mistral
* cleanup after dropping commits
* Add comment
Parent: 9094b87dd4
Commit: d363e71d0e
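For context, every hunk below converges on the same pattern: the model computes the RoPE cos/sin tensors once per forward pass and passes them to every attention layer as `position_embeddings`, so the layers no longer carry their own rotary embedding module or a `position_ids` fallback. A minimal, self-contained sketch of that pattern follows (illustrative names, not the actual transformers modules):

# Minimal sketch (illustrative names, not the actual transformers code) of the pattern the
# diffs below converge on: cos/sin are computed once at the model level and handed to every
# attention layer as `position_embeddings`; the layers keep no rotary_emb and no warning path.
from typing import Tuple

import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # [batch, seq] positions -> [batch, seq, head_dim] cos/sin tables
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


class Attention(nn.Module):
    """Receives the precomputed (cos, sin) tuple; it owns no rotary embedding of its own."""

    def forward(self, q, k, position_embeddings: Tuple[torch.Tensor, torch.Tensor]):
        cos, sin = position_embeddings  # the only source of cos/sin in the new code
        return apply_rotary_pos_emb(q, k, cos, sin)


# Usage: cos/sin are computed once at the model level and reused by every layer.
batch, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
rotary = RotaryEmbedding(head_dim)
position_embeddings = rotary(q, torch.arange(seq)[None, :])
q_rot, k_rot = Attention()(q, k, position_embeddings)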
@@ -228,9 +228,6 @@ class DummyAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = DummyRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -240,7 +237,7 @@ class DummyAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -254,15 +251,6 @@ class DummyAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -326,7 +314,7 @@ class DummyFlashAttention2(DummyAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
@@ -350,15 +338,6 @@ class DummyFlashAttention2(DummyAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -441,7 +420,7 @@ class DummySdpaAttention(DummyAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -472,15 +451,6 @@ class DummySdpaAttention(DummyAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -551,7 +521,7 @@ class DummyDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@@ -228,9 +228,6 @@ class Multimodal1TextAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = Multimodal1TextRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -240,7 +237,7 @@ class Multimodal1TextAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -254,15 +251,6 @@ class Multimodal1TextAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -326,7 +314,7 @@ class Multimodal1TextFlashAttention2(Multimodal1TextAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
@@ -350,15 +338,6 @@ class Multimodal1TextFlashAttention2(Multimodal1TextAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -441,7 +420,7 @@ class Multimodal1TextSdpaAttention(Multimodal1TextAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -472,15 +451,6 @@ class Multimodal1TextSdpaAttention(Multimodal1TextAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -553,7 +523,7 @@ class Multimodal1TextDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@@ -228,9 +228,6 @@ class SuperAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = SuperRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -240,7 +237,7 @@ class SuperAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -254,15 +251,6 @@ class SuperAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -326,7 +314,7 @@ class SuperFlashAttention2(SuperAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
@@ -350,15 +338,6 @@ class SuperFlashAttention2(SuperAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -441,7 +420,7 @@ class SuperSdpaAttention(SuperAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -472,15 +451,6 @@ class SuperSdpaAttention(SuperAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -551,7 +521,7 @@ class SuperDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@@ -115,8 +115,6 @@ class ChameleonRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon
-# TODO(joao): add me back asap :)
 class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
     """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -127,8 +125,6 @@ class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
         return cos, sin
 
 
-# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon
-# TODO(joao): add me back asap :)
 class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
     """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
@@ -283,9 +283,6 @@ class CohereAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
 
-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = CohereRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -295,7 +292,7 @@ class CohereAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -314,15 +311,6 @@ class CohereAttention(nn.Module):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -389,7 +377,7 @@ class CohereFlashAttention2(CohereAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
@@ -415,15 +403,6 @@ class CohereFlashAttention2(CohereAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -502,7 +481,7 @@ class CohereSdpaAttention(CohereAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -518,6 +497,7 @@ class CohereSdpaAttention(CohereAttention):
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -536,15 +516,6 @@ class CohereSdpaAttention(CohereAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -615,7 +586,7 @@ class CohereDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:

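The CohereSdpaAttention hunk above also adds `position_embeddings=position_embeddings` when falling back to the parent (eager) forward for `output_attentions=True`: since the parent no longer recomputes RoPE from `position_ids`, the precomputed tuple has to travel with the call. A small hypothetical sketch of that delegation, reusing the illustrative Attention class from the sketch near the top of this page (not the transformers classes):

# Hypothetical continuation of the earlier sketch: an SDPA variant that forwards the
# precomputed cos/sin tuple when it delegates to the eager parent implementation.
class SdpaAttention(Attention):
    def forward(self, q, k, position_embeddings, output_attentions=False):
        if output_attentions:
            # Mirrors the `position_embeddings=position_embeddings,` argument added above:
            # the parent has no rotary_emb left, so the tuple must be passed explicitly.
            return super().forward(q, k, position_embeddings=position_embeddings)
        # The scaled_dot_product_attention fast path would go here; this sketch just
        # reuses the eager math for both branches.
        return super().forward(q, k, position_embeddings=position_embeddings)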
@@ -197,33 +197,6 @@ class FalconRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
-class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
-    """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`FalconLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`FalconRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
-class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
-    """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`FalconDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`FalconRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
     batch_size, seq_length = attention_mask.shape
     closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
@@ -388,7 +361,7 @@ class FalconAttention(nn.Module):
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
         num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -402,15 +375,6 @@ class FalconAttention(nn.Module):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
         if alibi is None:
-            if position_embeddings is None:
-                logger.warning_once(
-                    "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                    "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                    "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                    "removed and `position_embeddings` will be mandatory."
-                )
-                cos, sin = self.rotary_emb(value_layer, position_ids)
-            else:
-                cos, sin = position_embeddings
+            cos, sin = position_embeddings
             query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
 
@@ -548,7 +512,7 @@ class FalconFlashAttention2(FalconAttention):
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
         num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
@@ -562,15 +526,6 @@ class FalconFlashAttention2(FalconAttention):
         value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
 
         if alibi is None:
-            if position_embeddings is None:
-                logger.warning_once(
-                    "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                    "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                    "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                    "removed and `position_embeddings` will be mandatory."
-                )
-                cos, sin = self.rotary_emb(value_layer, position_ids)
-            else:
-                cos, sin = position_embeddings
+            cos, sin = position_embeddings
             query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
 
@@ -695,7 +650,7 @@ class FalconDecoderLayer(nn.Module):
         use_cache: bool = False,
         output_attentions: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ):
         residual = hidden_states

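The Falcon hunks above (and the GPTNeoX and Llama hunks further down) also delete the deprecated linear-scaling and dynamic-NTK rotary-embedding subclasses. As the deprecation messages being removed state, the base rotary embedding now handles the scaling itself when the model config is passed to its constructor. A hedged sketch of that config-driven usage follows, written against the Llama classes; the exact `rope_scaling` keys are my assumption from the transformers version contemporary with this commit, not something stated in the diff:

# Hedged sketch: select RoPE scaling via the model config rather than a dedicated subclass.
# Assumes a transformers install from the era of this commit; the key names in `rope_scaling`
# mirror the `rope_type` value the deleted shim classes used to set.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
rope = LlamaRotaryEmbedding(config=config)  # stands in for the removed LlamaLinearScalingRotaryEmbedding

position_ids = torch.arange(16)[None, :]
x = torch.zeros(1, 16, config.hidden_size)  # only the dtype/device of `x` are used by the forward
cos, sin = rope(x, position_ids)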
@@ -227,7 +227,7 @@ class GlmAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -303,7 +303,7 @@ class GlmFlashAttention2(GlmAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         output_attentions = False
 
@@ -402,7 +402,7 @@ class GlmSdpaAttention(GlmAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -503,7 +503,7 @@ class GlmDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@@ -311,7 +311,7 @@ class GPTNeoXAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         padding_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         bsz, seq_len, _ = hidden_states.shape
 
@@ -404,7 +404,7 @@ class GPTNeoXAttention(nn.Module):
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         # Compute QKV
         # Attention heads [batch, seq_len, hidden_size]
@@ -427,15 +427,6 @@ class GPTNeoXAttention(nn.Module):
         key_rot = key[..., : self.rotary_ndims]
         key_pass = key[..., self.rotary_ndims :]
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
         query = torch.cat((query, query_pass), dim=-1)
@@ -583,33 +574,6 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
-class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
-    """GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`GPTNeoXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`GPTNeoXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
-class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
-    """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`GPTNeoXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`GPTNeoXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -688,7 +652,7 @@ class GPTNeoXLayer(nn.Module):
         layer_past: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         attention_layer_outputs = self.attention(
             self.input_layernorm(hidden_states),

@@ -105,7 +105,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         # Compute QKV
         # Attention heads [batch, seq_len, hidden_size]
@@ -128,15 +128,6 @@ class GPTNeoXJapaneseAttention(nn.Module):
         key_rot = key[..., : self.rotary_ndims]
         key_pass = key[..., self.rotary_ndims :]
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
         query = torch.cat((query, query_pass), dim=-1)
@@ -415,7 +406,7 @@ class GPTNeoXJapaneseLayer(nn.Module):
         layer_past: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         residual = hidden_states
         ln_out = self.input_layernorm(hidden_states)

@@ -242,7 +242,7 @@ class GraniteAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -318,7 +318,7 @@ class GraniteFlashAttention2(GraniteAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         output_attentions = False
 
@@ -417,7 +417,7 @@ class GraniteSdpaAttention(GraniteAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
@@ -520,7 +520,7 @@ class GraniteDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@ -458,7 +458,7 @@ class GraniteMoeAttention(nn.Module):
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
@ -535,7 +535,7 @@ class GraniteMoeFlashAttention2(GraniteMoeAttention):
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
output_attentions = False
|
output_attentions = False
|
||||||
|
|
||||||
@ -635,7 +635,7 @@ class GraniteMoeSdpaAttention(GraniteMoeAttention):
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
@ -739,7 +739,7 @@ class GraniteMoeDecoderLayer(nn.Module):
|
|||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
output_router_logits: Optional[bool] = False,
|
output_router_logits: Optional[bool] = False,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||

@ -168,31 +168,6 @@ class LlamaRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]

@ -284,9 +259,6 @@ class LlamaAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@ -296,7 +268,7 @@ class LlamaAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@ -310,15 +282,6 @@ class LlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -382,7 +345,7 @@ class LlamaFlashAttention2(LlamaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):

@ -406,15 +369,6 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -497,7 +451,7 @@ class LlamaSdpaAttention(LlamaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:

@ -528,15 +482,6 @@ class LlamaSdpaAttention(LlamaAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -607,7 +552,7 @@ class LlamaDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
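Editorial note, not part of the diff: every hunk above assumes the same calling convention, namely that the model computes the RoPE `(cos, sin)` tuple once per forward pass and hands it to each attention layer as `position_embeddings`, so the layers no longer rebuild it from `position_ids`. The following is a minimal, self-contained sketch of that convention with toy shapes; the two helpers mirror the `rotate_half` / `apply_rotary_pos_emb` context lines above, and everything else (shapes, base frequency) is illustrative only.

import torch

def rotate_half(x):
    # standard RoPE helper: swap and negate the two halves of the last dimension
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # broadcast cos/sin over the head dimension and rotate q and k
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# The model computes (cos, sin) once per forward pass from the position ids ...
head_dim, seq_len = 8, 4
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
position_ids = torch.arange(seq_len)[None, :]                        # (batch, seq)
freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]   # (batch, seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                               # (batch, seq, head_dim)
position_embeddings = (emb.cos(), emb.sin())                          # what gets passed to every layer

# ... and each attention layer only consumes the tuple, which is the code path the hunks keep.
q = torch.randn(1, 2, seq_len, head_dim)                              # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, seq_len, head_dim)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)
print(q.shape, k.shape)                                               # torch.Size([1, 2, 4, 8]) twice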

@ -314,10 +314,6 @@ class MistralFlashAttention2(MistralAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += cache_position[0]
-
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -858,7 +858,7 @@ class MllamaSelfAttentionDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:

@ -536,7 +536,7 @@ class NemotronDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@ -105,8 +105,6 @@ class OlmoRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
-# TODO(joao): add me back asap :)
 class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
     """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@ -117,8 +115,6 @@ class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
         return cos, sin


-# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
-# TODO(joao): add me back asap :)
 class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding):
     """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

@ -87,8 +87,6 @@ class Olmo2RotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo2
-# TODO(joao): add me back asap :)
 class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding):
     """Olmo2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@ -99,8 +97,6 @@ class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding):
         return cos, sin


-# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo2
-# TODO(joao): add me back asap :)
 class Olmo2DynamicNTKScalingRotaryEmbedding(Olmo2RotaryEmbedding):
     """Olmo2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

@ -143,33 +143,6 @@ class PersimmonRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
-class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
-    """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`PersimmonLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`PersimmonRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
-class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
-    """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`PersimmonDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`PersimmonRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""

@ -286,7 +259,7 @@ class PersimmonAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@ -305,15 +278,6 @@ class PersimmonAttention(nn.Module):
         value_states = value_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings

         # Partial rotary embedding

@ -390,7 +354,7 @@ class PersimmonDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
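Editorial note, not part of the diff: the Persimmon hunk above, like the matching Llama, Phi, and StableLm hunks in this commit, deletes the deprecated `*LinearScalingRotaryEmbedding` / `*DynamicNTKScalingRotaryEmbedding` subclasses. Their own deprecation messages name the replacement: the base rotary embedding class, configured through the model config's `rope_scaling`. A hedged sketch of that migration, using the Llama classes as the concrete example and assuming a transformers release that already contains this commit:

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Instead of instantiating LlamaLinearScalingRotaryEmbedding directly, declare the
# scaling in the config and let the single LlamaRotaryEmbedding pick it up.
linear_config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
linear_rotary_emb = LlamaRotaryEmbedding(config=linear_config)

# "dynamic" replaces LlamaDynamicNTKScalingRotaryEmbedding in the same way.
dynamic_config = LlamaConfig(rope_scaling={"rope_type": "dynamic", "factor": 2.0})
dynamic_rotary_emb = LlamaRotaryEmbedding(config=dynamic_config)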

@ -147,33 +147,6 @@ class PhiRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
-class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
-    """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`PhiLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`PhiRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
-class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
-    """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`PhiDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`PhiRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""

@ -294,7 +267,7 @@ class PhiAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@ -310,15 +283,6 @@ class PhiAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings

         # Partial rotary embedding

@ -406,7 +370,7 @@ class PhiFlashAttention2(PhiAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # PhiFlashAttention2 attention does not support output_attentions

@ -430,15 +394,6 @@ class PhiFlashAttention2(PhiAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings

         # Partial rotary embedding

@ -542,7 +497,7 @@ class PhiSdpaAttention(PhiAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.

@ -559,6 +514,8 @@ class PhiSdpaAttention(PhiAttention):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )

         bsz, q_len, _ = hidden_states.size()

@ -575,15 +532,6 @@ class PhiSdpaAttention(PhiAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings

         # Partial rotary embedding

@ -671,7 +619,7 @@ class PhiDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@ -284,7 +284,7 @@ class Qwen2Attention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@ -296,15 +296,6 @@ class Qwen2Attention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -370,7 +361,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         bsz, q_len, _ = hidden_states.size()

@ -382,15 +373,6 @@ class Qwen2FlashAttention2(Qwen2Attention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -479,7 +461,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.

@ -494,6 +476,8 @@ class Qwen2SdpaAttention(Qwen2Attention):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )

         bsz, q_len, _ = hidden_states.size()

@ -506,15 +490,6 @@ class Qwen2SdpaAttention(Qwen2Attention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -590,7 +565,7 @@ class Qwen2DecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
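Editorial note, not part of the diff: the `+ cache_position=...` / `+ position_embeddings=...` lines in the Sdpa hunks matter because the eager parent class can no longer rebuild RoPE from `position_ids` on its own, so the SDPA-to-eager fallback has to hand those inputs through. A minimal, self-contained sketch of that fallback pattern; the classes below are stand-ins for illustration, not library code.

from typing import Optional, Tuple
import torch
from torch import nn

class EagerAttention(nn.Module):
    """Stand-in for the eager attention class: it now requires position_embeddings."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        cos, sin = position_embeddings  # fails loudly if the caller forgot to forward the tuple
        return hidden_states * cos + hidden_states * sin, torch.ones(1)  # stand-ins for (output, weights)

class SdpaAttention(EagerAttention):
    """Stand-in for the SDPA subclass: defers to eager when attention weights are requested."""

    def forward(self, hidden_states, output_attentions=False, cache_position=None, position_embeddings=None):
        if output_attentions:
            # The fallback must forward cache_position and position_embeddings as well,
            # which is what the added lines in the hunks above do.
            return super().forward(
                hidden_states,
                output_attentions=output_attentions,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
        cos, sin = position_embeddings
        return hidden_states * cos + hidden_states * sin, None

x = torch.randn(1, 4, 8)
cos, sin = torch.ones(1, 4, 8), torch.zeros(1, 4, 8)
out, attn = SdpaAttention()(x, output_attentions=True, position_embeddings=(cos, sin))
print(out.shape, attn)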

@ -368,7 +368,7 @@ class Qwen2MoeAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@ -380,17 +380,7 @@ class Qwen2MoeAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:

@ -457,7 +447,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         bsz, q_len, _ = hidden_states.size()

@ -469,15 +459,6 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -567,7 +548,7 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.

@ -582,6 +563,8 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )

         bsz, q_len, _ = hidden_states.size()

@ -594,15 +577,6 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@ -742,7 +716,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """

@ -537,7 +537,7 @@ class Qwen2VLAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@ -549,15 +549,6 @@ class Qwen2VLAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]

@ -630,7 +621,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         bsz, q_len, _ = hidden_states.size()

@ -643,17 +634,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         # Because the input can be padded, the absolute sequence length depends on the max position id.
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings

         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
         )

@ -742,7 +723,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.

@ -758,6 +739,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )

         bsz, q_len, _ = hidden_states.size()

@ -770,15 +752,6 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_multimodal_rotary_pos_emb(
             query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]

@ -856,7 +829,7 @@ class Qwen2VLDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -149,33 +149,6 @@ class StableLmRotaryEmbedding(nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
-class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
-    """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`StableLmLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`StableLmRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
-class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
-    """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`StableLmDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`StableLmRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
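The deleted subclasses only injected a `rope_type` and deferred to `StableLmRotaryEmbedding`, so the replacement is a direct substitution. A hedged sketch of that migration, mirroring the constructor arguments the updated tests further down pass (the concrete values here are illustrative, and the exact keyword set should be checked against the installed transformers version):

from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding

head_dim = 64  # illustrative

# Previously: StableLmLinearScalingRotaryEmbedding(head_dim, ..., scaling_factor=2.0)
linear_rope = StableLmRotaryEmbedding(
    head_dim,
    max_position_embeddings=4096,
    base=10_000,
    scaling_factor=2.0,
    rope_type="linear",
)

# Previously: StableLmDynamicNTKScalingRotaryEmbedding(head_dim, ..., scaling_factor=2.0)
dynamic_rope = StableLmRotaryEmbedding(
    head_dim,
    max_position_embeddings=4096,
    base=10_000,
    scaling_factor=2.0,
    rope_type="dynamic",
)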
@@ -307,7 +280,7 @@ class StableLmAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -323,15 +296,6 @@ class StableLmAttention(nn.Module):
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
 
         # Partial rotary embedding
@@ -403,7 +367,7 @@ class StableLmSdpaAttention(StableLmAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -418,6 +382,8 @@ class StableLmSdpaAttention(StableLmAttention):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -434,15 +400,6 @@ class StableLmSdpaAttention(StableLmAttention):
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
 
         # Partial rotary embedding
@@ -533,7 +490,7 @@ class StableLmFlashAttention2(StableLmAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # StableLmFlashAttention2 attention does not support output_attentions
@@ -557,15 +514,6 @@ class StableLmFlashAttention2(StableLmAttention):
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
 
         # Partial rotary embedding
@@ -650,7 +598,7 @@ class StableLmDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -251,8 +251,6 @@ class Starcoder2Attention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
 
-        self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -262,7 +260,7 @@ class Starcoder2Attention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -274,15 +272,6 @@ class Starcoder2Attention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -346,7 +335,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         bsz, q_len, _ = hidden_states.size()
 
@@ -358,15 +347,6 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -446,7 +426,7 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -461,6 +441,8 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -473,15 +455,6 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -555,7 +528,7 @@ class Starcoder2DecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -117,8 +117,6 @@ class Starcoder2Attention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
 
-        self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -128,7 +126,7 @@ class Starcoder2Attention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -140,15 +138,6 @@ class Starcoder2Attention(nn.Module):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -212,7 +201,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ):
         bsz, q_len, _ = hidden_states.size()
 
@@ -224,15 +213,6 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -312,7 +292,7 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -327,6 +307,8 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
             )
 
         bsz, q_len, _ = hidden_states.size()
@@ -339,15 +321,6 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
@@ -50,8 +50,6 @@ if is_torch_available():
         FalconModel,
     )
     from transformers.models.falcon.modeling_falcon import (
-        FalconDynamicNTKScalingRotaryEmbedding,
-        FalconLinearScalingRotaryEmbedding,
         FalconRotaryEmbedding,
     )
 
@@ -484,11 +482,12 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
 
         # Sanity check linear RoPE scaling
        # New position "x" should match original position with index "x/scaling_factor"
-        linear_scaling_rope = FalconLinearScalingRotaryEmbedding(
+        linear_scaling_rope = FalconRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="linear",
         ).to(torch_device)
         linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
         linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
@@ -502,11 +501,12 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         # Sanity check Dynamic NTK RoPE scaling
         # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
         # with scaling_factor (or that `inv_freq` decreases)
-        ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding(
+        ntk_scaling_rope = FalconRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="dynamic",
         ).to(torch_device)
         ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
         ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
@@ -37,11 +37,7 @@ if is_torch_available():
         GPTNeoXForTokenClassification,
         GPTNeoXModel,
     )
-    from transformers.models.gpt_neox.modeling_gpt_neox import (
-        GPTNeoXDynamicNTKScalingRotaryEmbedding,
-        GPTNeoXLinearScalingRotaryEmbedding,
-        GPTNeoXRotaryEmbedding,
-    )
+    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
 
 
 class GPTNeoXModelTester:
@@ -400,11 +396,12 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 
         # Sanity check linear RoPE scaling
         # New position "x" should match original position with index "x/scaling_factor"
-        linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding(
+        linear_scaling_rope = GPTNeoXRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rotary_emb_base,
             scaling_factor=scaling_factor,
+            rope_type="linear",
         ).to(torch_device)
         linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
         linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
@@ -418,11 +415,12 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         # Sanity check Dynamic NTK RoPE scaling
         # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
         # with scaling_factor (or that `inv_freq` decreases)
-        ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding(
+        ntk_scaling_rope = GPTNeoXRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rotary_emb_base,
             scaling_factor=scaling_factor,
+            rope_type="dynamic",
         ).to(torch_device)
         ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
         ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
@@ -51,7 +51,7 @@ if is_torch_available():
         LlamaModel,
         LlamaTokenizer,
     )
-    from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding, LlamaRotaryEmbedding
+    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 
 
 class LlamaModelTester:
@@ -489,43 +489,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)
 
-    def test_rope_class_retrocompatibility(self):
-        # Delete me when we remove compatibility for the old API :)
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-        config.rope_scaling = {"type": "linear", "factor": 10}
-
-        # Inputs
-        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exlusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Old API -- under the hood, "type": "linear" is set and `LlamaRotaryEmbedding` is called
-        old_api_rope = LlamaLinearScalingRotaryEmbedding(
-            config.hidden_size // config.num_attention_heads,
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta,
-            scaling_factor=scaling_factor,
-        ).to(torch_device)
-        old_cos_short, old_sin_short = old_api_rope(x, position_ids_short)
-        old_cos_long, old_sin_long = old_api_rope(x, position_ids_long)
-
-        # New API
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        new_api_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
-        new_cos_short, new_sin_short = new_api_rope(x, position_ids_short)
-        new_cos_long, new_sin_long = new_api_rope(x, position_ids_long)
-
-        # The results should match
-        torch.testing.assert_close(old_cos_short, new_cos_short)
-        torch.testing.assert_close(old_sin_short, new_sin_short)
-        torch.testing.assert_close(old_cos_long, new_cos_long)
-        torch.testing.assert_close(old_sin_long, new_sin_long)
-
     def test_model_loading_old_rope_configs(self):
         def _reinitialize_config(base_config, new_kwargs):
             # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
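The deleted retrocompatibility test above exercised both construction paths; only the config-driven one survives. A trimmed sketch of that surviving path, lifted from the removed test body (the default `LlamaConfig` and the literal values here are illustrative, not a prescribed setup):

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig()                                  # illustrative config for the sketch
config.rope_scaling = {"type": "linear", "factor": 10}  # old-style dict, as used in the removed test
rope = LlamaRotaryEmbedding(config=config)

x = torch.randn(1)                                      # used only for dtype/device, as in the removed test
position_ids = torch.arange(10).unsqueeze(0)
cos, sin = rope(x, position_ids)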
@@ -46,11 +46,7 @@ if is_torch_available():
         PersimmonForTokenClassification,
         PersimmonModel,
     )
-    from transformers.models.persimmon.modeling_persimmon import (
-        PersimmonDynamicNTKScalingRotaryEmbedding,
-        PersimmonLinearScalingRotaryEmbedding,
-        PersimmonRotaryEmbedding,
-    )
+    from transformers.models.persimmon.modeling_persimmon import PersimmonRotaryEmbedding
 
 
 # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
@@ -451,11 +447,12 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
 
         # Sanity check linear RoPE scaling
         # New position "x" should match original position with index "x/scaling_factor"
-        linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding(
+        linear_scaling_rope = PersimmonRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="linear",
         ).to(torch_device)
         linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
         linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
@@ -469,11 +466,12 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         # Sanity check Dynamic NTK RoPE scaling
         # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
         # with scaling_factor (or that `inv_freq` decreases)
-        ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding(
+        ntk_scaling_rope = PersimmonRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="dynamic",
         ).to(torch_device)
         ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
         ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
@@ -42,11 +42,7 @@ if is_torch_available():
         PhiForTokenClassification,
         PhiModel,
     )
-    from transformers.models.phi.modeling_phi import (
-        PhiDynamicNTKScalingRotaryEmbedding,
-        PhiLinearScalingRotaryEmbedding,
-        PhiRotaryEmbedding,
-    )
+    from transformers.models.phi.modeling_phi import PhiRotaryEmbedding
 
 
 class PhiModelTester:
@@ -430,11 +426,12 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
         # Sanity check linear RoPE scaling
         # New position "x" should match original position with index "x/scaling_factor"
-        linear_scaling_rope = PhiLinearScalingRotaryEmbedding(
+        linear_scaling_rope = PhiRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="linear",
         ).to(torch_device)
         linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
         linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
@@ -448,11 +445,12 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         # Sanity check Dynamic NTK RoPE scaling
         # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
         # with scaling_factor (or that `inv_freq` decreases)
-        ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding(
+        ntk_scaling_rope = PhiRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="dynamic",
         ).to(torch_device)
         ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
         ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
@@ -44,11 +44,7 @@ if is_torch_available():
         StableLmForTokenClassification,
         StableLmModel,
     )
-    from transformers.models.stablelm.modeling_stablelm import (
-        StableLmDynamicNTKScalingRotaryEmbedding,
-        StableLmLinearScalingRotaryEmbedding,
-        StableLmRotaryEmbedding,
-    )
+    from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding
 
 
 # Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
@@ -436,11 +432,12 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
 
         # Sanity check linear RoPE scaling
         # New position "x" should match original position with index "x/scaling_factor"
-        linear_scaling_rope = StableLmLinearScalingRotaryEmbedding(
+        linear_scaling_rope = StableLmRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="linear",
         ).to(torch_device)
         linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
         linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
@@ -454,11 +451,12 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         # Sanity check Dynamic NTK RoPE scaling
         # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
         # with scaling_factor (or that `inv_freq` decreases)
-        ntk_scaling_rope = StableLmDynamicNTKScalingRotaryEmbedding(
+        ntk_scaling_rope = StableLmRotaryEmbedding(
             head_dim,
             max_position_embeddings=config.max_position_embeddings,
             base=config.rope_theta,
             scaling_factor=scaling_factor,
+            rope_type="dynamic",
        ).to(torch_device)
         ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
         ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)