Mirror of https://github.com/huggingface/transformers.git
Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head))" (#22444)
Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head)) (#21627)"
This reverts commit bad8300837.
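For reference, the two formulations are algebraically equivalent: dividing q by sqrt(dim_per_head) before the matmul distributes over the product, so (q / sqrt(d)) · k^T equals (q · k^T) / sqrt(d). A minimal sketch of this equivalence, using hypothetical tensor shapes not taken from the commit:

import math

import torch

# Hypothetical shapes for illustration only.
bs, n_heads, qlen, klen, dim_per_head = 2, 4, 5, 5, 16
q = torch.randn(bs, n_heads, qlen, dim_per_head)
k = torch.randn(bs, n_heads, klen, dim_per_head)

# Form introduced by the reverted commit: scale the q.k^T dot product.
scores_scaled_product = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)

# Form this revert restores: scale q before the matmul.
scores_scaled_q = torch.matmul(q / math.sqrt(dim_per_head), k.transpose(2, 3))

# Identical up to floating-point rounding.
assert torch.allclose(scores_scaled_product, scores_scaled_q, atol=1e-6)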
parent 8894b81742
commit 55dae94c0c
@@ -172,7 +172,8 @@ class MultiHeadAttention(nn.Module):
                 k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)
 
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)
@@ -176,7 +176,8 @@ class MultiHeadAttention(nn.Module):
                 k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)
 
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)
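Read on its own, the restored scoring path behaves as in the sketch below, a self-contained rewrite with hypothetical inputs; mask_reshape is assumed to be (bs, 1, 1, klen), a shape that broadcasts a 2D padding mask over the scores:

import math

import torch

# Hypothetical shapes and mask, for illustration only.
bs, n_heads, qlen, klen, dim_per_head = 2, 4, 5, 5, 16
q = torch.randn(bs, n_heads, qlen, dim_per_head)
k = torch.randn(bs, n_heads, klen, dim_per_head)
mask = torch.ones(bs, klen, dtype=torch.long)  # padding mask; 0 marks padded key positions
mask[:, -1] = 0  # pretend the last key position is padding
mask_reshape = (bs, 1, 1, klen)

q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)

# Masked positions receive ~zero weight after softmax.
weights = torch.softmax(scores, dim=-1)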