regenerate modeling

This commit is contained in:
JiwenJ 2025-04-20 07:19:34 +00:00
parent 269a49e5d6
commit 90ce1658c3

View File

@ -289,7 +289,12 @@ class PLMAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
batch_size, seq_length = hidden_states.shape[:-1]
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
key_shape = (
batch_size,
seq_length,
-1,
self.qk_nope_head_dim + self.v_head_dim,
)
if self.q_lora_rank is not None:
q_states = (
self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)