Fix #3954 - GPT2 is not traceable (#3955)

* Update sqrt computation so it can survive a torch.jit.trace

* Update modeling_gpt2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
jazzcook15 2020-04-28 12:18:56 -07:00 committed by GitHub
parent 9a0a8c1c6f
commit c7d06b79ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,6 @@
import logging
import math
import os
import torch
@ -143,7 +142,7 @@ class Attention(nn.Module):
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
w = w / (v.size(-1) ** 0.5)
nd, ns = w.size(-2), w.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask, w, self.masked_bias)