Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Remove unnecessary unsqueeze-squeeze in rotary positional embedding (#26162)

* remove unnecessary unsqueeze-squeeze in llama
* correct other models
* fix
* revert gpt_neox_japanese
* fix copies
* fix test
parent 65aabafe2f
commit 64845307b3
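Every modeling hunk below makes the same change: cos/sin are cached as 2-D [seq_len, dim] tensors instead of 4-D [1, 1, seq_len, dim] tensors, so callers slice with [:seq_len] and index with position_ids directly instead of squeezing the singleton dimensions back out and re-adding them. A minimal standalone sketch of why the two indexing paths give identical results (illustrative only, not code from this commit):

import torch

seq_len, dim = 16, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)

# Old layout: [1, 1, seq_len, dim], sliced with [:, :, :seq_len, ...] and squeezed later.
cos_old = emb.cos()[None, None, :, :]
# New layout: [seq_len, dim], sliced with [:seq_len] and indexed directly.
cos_new = emb.cos()

position_ids = torch.arange(seq_len)[None, :]  # [bs=1, seq_len]
old_path = cos_old.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # what apply_rotary_pos_emb used to do
new_path = cos_new[position_ids].unsqueeze(1)                        # what it does now
assert torch.equal(old_path, new_path)  # both end up as [bs, 1, seq_len, dim]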
@@ -118,8 +118,8 @@ class OpenLlamaRotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -127,8 +127,8 @@ class OpenLlamaRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -148,8 +148,8 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
@@ -175,8 +175,8 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def rotate_half(x):
@@ -188,11 +188,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
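The OpenLlama hunks above show the full per-class pattern; the shape contract after the change can be summarized with a simplified stand-in module (RotarySketch is a made-up name, not a transformers class, and the real classes also regrow the cache lazily via _set_cos_sin_cache when seq_len exceeds the cached length):

import torch
from torch import nn


class RotarySketch(nn.Module):
    """Simplified stand-in for the *RotaryEmbedding classes after this change."""

    def __init__(self, dim, max_positions=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_positions).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cached as 2-D [max_positions, dim]; no [None, None] broadcast dims up front.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len):
        # Plain [:seq_len] slicing replaces the old [:, :, :seq_len, ...] indexing.
        return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)


cos, sin = RotarySketch(dim=8)(torch.zeros(1, 4, 16, 8), seq_len=16)
assert cos.shape == (16, 8) and sin.shape == (16, 8)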
@@ -115,8 +115,8 @@ class FalconRotaryEmbedding(nn.Module):
         if dtype in [torch.float16, torch.bfloat16]:
             emb = emb.float()

-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
@@ -133,8 +133,8 @@ class FalconRotaryEmbedding(nn.Module):
             self.sin_cached = self.sin_cached.to(device)

         # Gather cos, sin at the designated position ids
-        cos = self.cos_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
-        sin = self.sin_cached.squeeze(0)[position_ids]  # [bs, seq_len, dim]
+        cos = self.cos_cached[position_ids]  # [bs, seq_len, dim]
+        sin = self.sin_cached[position_ids]  # [bs, seq_len, dim]
         return cos, sin

     def forward(self, query, key, past_key_values_length, position_ids):
@@ -181,8 +181,8 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
         if dtype in [torch.float16, torch.bfloat16]:
             emb = emb.float()

-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
@@ -215,8 +215,8 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
         if dtype in [torch.float16, torch.bfloat16]:
             emb = emb.float()

-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

         self.cos_cached = self.cos_cached.type(dtype)
         self.sin_cached = self.sin_cached.type(dtype)
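Falcon differs slightly: instead of returning [:seq_len] slices from forward, its helper gathers cos/sin at the given position_ids. With the leading batch dimension gone from the cache, the squeeze(0) disappears and the gather becomes a plain advanced index. A small sketch of just that step (shapes and variable names are illustrative):

import torch

cos_cached = torch.randn(32, 8)                      # new cache layout: [max_seq_len, head_dim]
position_ids = torch.tensor([[5, 6, 7], [0, 1, 2]])  # [bs, q_len]

# Old: cos_cached was [1, max_seq_len, head_dim] and needed .squeeze(0) before indexing.
cos = cos_cached[position_ids]                       # [bs, q_len, head_dim]
assert cos.shape == (2, 3, 8)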
@@ -309,8 +309,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -318,8 +318,8 @@ class GPTNeoXRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -339,8 +339,8 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
@@ -366,8 +366,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def rotate_half(x):
@@ -379,11 +379,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -261,8 +261,8 @@ class RotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -270,8 +270,8 @@ class RotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -523,8 +523,8 @@ class IdeficsEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -532,8 +532,8 @@ class IdeficsEmbedding(torch.nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -546,11 +546,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -138,8 +138,8 @@ class LlamaRotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -147,8 +147,8 @@ class LlamaRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -167,8 +167,8 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -193,8 +193,8 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def rotate_half(x):
@@ -204,12 +204,10 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
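Since the other models mark their apply_rotary_pos_emb as copied from the Llama/GPT-NeoX definition, here is the function as it stands after this change, shown self-contained with the standard rotate_half helper (whose body is not part of this diff and is reproduced here for context) and a shape check appended for illustration:

import torch


def rotate_half(x):
    # Rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin now arrive as [seq_len, head_dim]; one gather plus unsqueeze is enough
    # to broadcast over [bs, num_heads, seq_len, head_dim] queries and keys.
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


bs, heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)
cos, sin = torch.randn(seq_len, head_dim), torch.randn(seq_len, head_dim)
position_ids = torch.arange(seq_len)[None, :].expand(bs, -1)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape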
@@ -149,8 +149,8 @@ class MistralRotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -158,8 +158,8 @@ class MistralRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -173,11 +173,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -94,8 +94,8 @@ class PersimmonRotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -103,8 +103,8 @@ class PersimmonRotaryEmbedding(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
         )


@@ -124,8 +124,8 @@ class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
@@ -151,8 +151,8 @@ class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -165,11 +165,8 @@ def rotate_half(x):

 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
+    sin = sin[position_ids].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
@@ -430,7 +430,8 @@ class MistralIntegrationTest(unittest.TestCase):
     def test_model_7b_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
         model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
-        out = model(torch.tensor([input_ids])).logits
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        out = model(input_ids).logits.cpu()
         # Expected mean on dim = -1
         EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
         torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
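The test change is about device placement rather than RoPE: with device_map="auto" the sharded model's embedding weights may live off the default device, so the input ids are moved there explicitly and the logits are pulled back to the CPU before comparison. A hedged sketch of the general pattern for any sharded PreTrainedModel (the helper name is invented and not part of the test):

import torch


def logits_on_cpu(model, token_ids):
    # Place inputs where the input embeddings live, then compare on CPU so the
    # expected tensors in a test need no device handling of their own.
    embed_device = model.get_input_embeddings().weight.device
    input_ids = torch.tensor([token_ids]).to(embed_device)
    with torch.no_grad():
        return model(input_ids).logits.cpu()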