mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 09:42:22 +06:00
* Update sqrt computation so it can survive a torch.jit.trace
* Update modeling_gpt2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
9a0a8c1c6f
commit
c7d06b79ae
@@ -17,7 +17,6 @@
|
||||
|
||||
|
||||
import logging
|
||||
- import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
@@ -143,7 +142,7 @@ class Attention(nn.Module):
|
||||
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
|
||||
w = torch.matmul(q, k)
|
||||
if self.scale:
|
||||
- w = w / math.sqrt(v.size(-1))
|
||||
+ w = w / (v.size(-1) ** 0.5)
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||
w = torch.where(mask, w, self.masked_bias)
|
||||
|
Loading…
Reference in New Issue
Block a user