Mirror of https://github.com/huggingface/transformers.git
Synced 2025-07-31 02:02:21 +06:00
Fix bugs in the placement of "GroupNorm with scale", etc.
This commit is contained in:
parent 107bd3c340
commit 26307d92f6
@@ -314,7 +314,7 @@ class DiffLlamaAttention(nn.Module):
         self.lambda_k1 = nn.Parameter(torch.normal(0, 0.1, size=(self.head_dim,)))
         self.lambda_q2 = nn.Parameter(torch.normal(0, 0.1, size=(self.head_dim,)))
         self.lambda_k2 = nn.Parameter(torch.normal(0, 0.1, size=(self.head_dim,)))
-        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+        self.groupnorm = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps, elementwise_affine=False)

         # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = DiffLlamaRotaryEmbedding(config=self.config)
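The constructor change above tracks the move of the norm call in the hunks below: RMSNorm(2 * self.head_dim) matches a per-head tensor whose last dimension is 2 * head_dim, while RMSNorm(self.hidden_size) matches the flattened (bsz, q_len, hidden_size) tensor that only exists after the transpose/reshape. A minimal, self-contained shape sketch, not the transformers code; every dimension below is made up, and the hidden_size relation is an assumption of the sketch:

import torch
import torch.nn as nn  # nn.RMSNorm requires a recent PyTorch (>= 2.4)

bsz, q_len, n_heads, head_dim = 2, 7, 4, 16              # illustrative values only
hidden_size = n_heads * 2 * head_dim                     # assumption for this sketch

# Per-head layout, before transpose/reshape: (bsz, n_heads, q_len, 2 * head_dim)
per_head = torch.randn(bsz, n_heads, q_len, 2 * head_dim)
norm_per_head = nn.RMSNorm(2 * head_dim, eps=1e-6, elementwise_affine=False)
print(norm_per_head(per_head).shape)                     # torch.Size([2, 4, 7, 32])

# Flattened layout, after transpose(1, 2) + reshape: (bsz, q_len, hidden_size)
flat = per_head.transpose(1, 2).reshape(bsz, q_len, -1)
norm_flat = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=False)
print(norm_flat(flat).shape)                             # torch.Size([2, 7, 128])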
@@ -400,9 +400,9 @@ class DiffLlamaAttention(nn.Module):
         attn_output2 = torch.matmul(attn_weights2, value_states)

         attn_output = attn_output1 - lambda_full * attn_output2
-        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)

         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
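In the eager path the "GroupNorm with scale" step, (1 - self.lambda_init) * self.groupnorm(...), now runs after the transpose/reshape, i.e. on the flattened (bsz, q_len, hidden_size) output instead of the per-head one. A sketch of the new ordering with placeholder tensors; the lambda values and shapes are invented for illustration, and this is not DiffLlamaAttention.forward itself:

import torch
import torch.nn as nn

bsz, n_heads, q_len, dim_per_head = 2, 4, 7, 32          # dim_per_head stands in for 2 * head_dim
hidden_size = n_heads * dim_per_head
lambda_init, lambda_full = 0.8, 0.6                      # placeholders, not the model's values

attn_output1 = torch.randn(bsz, n_heads, q_len, dim_per_head)
attn_output2 = torch.randn(bsz, n_heads, q_len, dim_per_head)
groupnorm = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=False)

attn_output = attn_output1 - lambda_full * attn_output2     # differential combination
attn_output = attn_output.transpose(1, 2).contiguous()      # (bsz, q_len, n_heads, dim_per_head)
attn_output = attn_output.reshape(bsz, q_len, -1)           # (bsz, q_len, hidden_size)
attn_output = (1 - lambda_init) * groupnorm(attn_output)    # scaled norm, now after the reshape
print(attn_output.shape)                                    # torch.Size([2, 7, 128])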
@@ -577,8 +577,9 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         lambda_full = lambda_1 - lambda_2 + self.lambda_init

         attn_output = attn_output1 - lambda_full * attn_output2
-        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)

         attn_output = self.o_proj(attn_output)

         if not output_attentions:
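The FlashAttention2 branch gets the same move; the only difference from the eager path is that its output is flattened directly with reshape(bsz, q_len, -1).contiguous(), with no transpose before it. A short placeholder sketch, where the (bsz, q_len, n_heads, dim) layout is assumed purely for illustration:

import torch
import torch.nn as nn

bsz, q_len, n_heads, dim = 2, 7, 4, 32                   # illustrative only
lambda_init, lambda_full = 0.8, 0.6                      # placeholders
groupnorm = nn.RMSNorm(n_heads * dim, eps=1e-6, elementwise_affine=False)

attn_output1 = torch.randn(bsz, q_len, n_heads, dim)
attn_output2 = torch.randn(bsz, q_len, n_heads, dim)

attn_output = attn_output1 - lambda_full * attn_output2
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()   # flatten heads; no transpose here
attn_output = (1 - lambda_init) * groupnorm(attn_output)         # scaled norm after flattening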
@@ -715,9 +716,10 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention):
         lambda_full = lambda_1 - lambda_2 + self.lambda_init

         attn_output = attn_output1 - lambda_full * attn_output2
-        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)

         attn_output = self.o_proj(attn_output)

         return attn_output, None, past_key_value
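The SDPA hunk applies the same reordering as the eager path (transpose, view, then scaled norm), so after this commit all three branches end with the same step: apply (1 - lambda_init) * groupnorm to the already flattened (bsz, q_len, hidden_size) tensor before o_proj. A hypothetical helper, not present in the diff, stating that shared step in one place:

import torch
import torch.nn as nn

def scaled_groupnorm(attn_output_flat: torch.Tensor, groupnorm: nn.Module, lambda_init: float) -> torch.Tensor:
    # attn_output_flat is expected to be (bsz, q_len, hidden_size), i.e. already
    # transposed and reshaped (eager, SDPA) or reshaped (FlashAttention2).
    return (1 - lambda_init) * groupnorm(attn_output_flat)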