Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head))" (#22444)

Revert "Error (also in original) model, scaling only q matrix not qk.T dot product (qk.T/sqrt(dim_per_head)) (#21627)"

This reverts commit bad8300837.
Author: Sylvain Gugger (committed by GitHub)
Date: 2023-03-29 10:59:42 -04:00
Parent: 8894b81742
Commit: 55dae94c0c
2 changed files with 4 additions and 2 deletions
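
Note: the two formulations differ only in where the 1/sqrt(dim_per_head) factor is applied. Because dividing by a scalar commutes with matrix multiplication, scaling q before the dot product (the behavior restored below) and scaling the q @ k.T product afterwards (introduced in #21627) are mathematically equivalent. A minimal sketch checking this, where all shapes are illustrative assumptions rather than the model's real sizes:

    import math

    import torch

    # Illustrative shapes (assumptions, not the model's real dimensions).
    bs, n_heads, qlen, klen, dim_per_head = 2, 4, 5, 7, 8
    q = torch.randn(bs, n_heads, qlen, dim_per_head)
    k = torch.randn(bs, n_heads, klen, dim_per_head)

    # Behavior restored by this revert: scale q, then take the dot product.
    scores_scale_q = torch.matmul(q / math.sqrt(dim_per_head), k.transpose(2, 3))

    # Behavior introduced by #21627: take the dot product, then scale it.
    scores_scale_dot = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)

    # The two agree up to floating-point rounding.
    assert torch.allclose(scores_scale_q, scores_scale_dot, atol=1e-6)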


@@ -172,7 +172,8 @@ class MultiHeadAttention(nn.Module):
                     k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)


@@ -176,7 +176,8 @@ class MultiHeadAttention(nn.Module):
                     k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)
-        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
+        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)
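
The unchanged context lines above also show the masking step: key positions where mask == 0 are filled with the smallest value representable in the score dtype, so softmax assigns them effectively zero attention weight. A self-contained sketch of that pattern, where the shapes and the example mask are assumptions for illustration:

    import torch

    # Illustrative shapes: bs=1, n_heads=1, qlen=2, klen=3 (assumptions).
    scores = torch.randn(1, 1, 2, 3)
    mask = torch.tensor([[1, 1, 0]])  # 0 marks a padded key position

    # Broadcast the inverted mask to the score shape, as in the diff.
    expanded = (mask == 0).view(1, 1, 1, 3).expand_as(scores)

    # Fill masked positions with the dtype's minimum so softmax gives
    # them (effectively) zero weight.
    scores.masked_fill_(expanded, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores.float(), dim=-1)
    print(weights[..., -1])  # ~0 everywhere for the masked key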