fix placement of "GroupNorm with scale" and related bugs

weak-kajuma 2024-10-21 12:44:33 +00:00
parent 107bd3c340
commit 26307d92f6


@@ -314,7 +314,7 @@ class DiffLlamaAttention(nn.Module):
         self.lambda_k1 = nn.Parameter(torch.normal(0, 0.1, size=(self.head_dim,)))
         self.lambda_q2 = nn.Parameter(torch.normal(0, 0.1, size=(self.head_dim,)))
         self.lambda_k2 = nn.Parameter(torch.normal(0, 0.1, size=(self.head_dim,)))
-        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+        self.groupnorm = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps, elementwise_affine=False)
 
         # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
         self.rotary_emb = DiffLlamaRotaryEmbedding(config=self.config)
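
For reference, a minimal sketch of what this hunk changes: the parameter-free RMSNorm used as "GroupNorm with scale" is now sized to hidden_size (applied after the heads are merged) instead of 2 * head_dim (applied per head). The tensor sizes below are illustrative assumptions, not values read from the model config, and nn.RMSNorm requires PyTorch >= 2.4.

import torch
import torch.nn as nn

head_dim, num_heads = 16, 4
hidden_size = num_heads * 2 * head_dim  # assumed relation, for illustration only

# Before: normalize each head's 2 * head_dim value channels separately.
norm_per_head = nn.RMSNorm(2 * head_dim, eps=1e-6, elementwise_affine=False)
per_head = torch.randn(2, num_heads, 8, 2 * head_dim)
print(norm_per_head(per_head).shape)  # torch.Size([2, 4, 8, 32])

# After this commit: one norm over the merged hidden_size, applied post-reshape.
norm_merged = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=False)
merged = torch.randn(2, 8, hidden_size)
print(norm_merged(merged).shape)  # torch.Size([2, 8, 128])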
@@ -400,9 +400,9 @@ class DiffLlamaAttention(nn.Module):
         attn_output2 = torch.matmul(attn_weights2, value_states)
 
         attn_output = attn_output1 - lambda_full * attn_output2
-        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
 
         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
@@ -577,8 +577,9 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention):
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
         attn_output = attn_output1 - lambda_full * attn_output2
-        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
@@ -715,9 +716,10 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention):
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
         attn_output = attn_output1 - lambda_full * attn_output2
-        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
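
Taken together, the eager, FlashAttention-2, and SDPA paths now share the same ordering: merge the heads first, then normalize and rescale, then project. A minimal sketch of that ordering under assumed shapes; bsz, num_heads, q_len, head_dim, lambda_init, and lambda_full below are illustrative placeholders, not the real DiffLlamaAttention attributes.

import torch
import torch.nn as nn

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
hidden_size = num_heads * 2 * head_dim   # assumed relation, for illustration only
lambda_init, lambda_full = 0.8, 0.05     # placeholder values

groupnorm = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=False)
o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Outputs of the two attention maps, each (bsz, num_heads, q_len, 2 * head_dim).
attn_output1 = torch.randn(bsz, num_heads, q_len, 2 * head_dim)
attn_output2 = torch.randn(bsz, num_heads, q_len, 2 * head_dim)

# Differential attention: subtract the second map's output.
attn_output = attn_output1 - lambda_full * attn_output2

# Merge heads first ...
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)   # (bsz, q_len, hidden_size)

# ... then apply "GroupNorm with scale" and the output projection.
attn_output = (1 - lambda_init) * groupnorm(attn_output)
attn_output = o_proj(attn_output)

print(attn_output.shape)  # torch.Size([2, 8, 128])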