mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
* Update sqrt computation so it can survive a torch.jit.trace * Update modeling_gpt2.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
9a0a8c1c6f
commit
c7d06b79ae
@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -143,7 +142,7 @@ class Attention(nn.Module):
|
|||||||
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
|
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
|
||||||
w = torch.matmul(q, k)
|
w = torch.matmul(q, k)
|
||||||
if self.scale:
|
if self.scale:
|
||||||
w = w / math.sqrt(v.size(-1))
|
w = w / (v.size(-1) ** 0.5)
|
||||||
nd, ns = w.size(-2), w.size(-1)
|
nd, ns = w.size(-2), w.size(-1)
|
||||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||||
w = torch.where(mask, w, self.masked_bias)
|
w = torch.where(mask, w, self.masked_bias)
|
||||||
|
Loading…
Reference in New Issue
Block a user