From c7d06b79ae8ea9c686f44ffdd6ab954ee6e10fd4 Mon Sep 17 00:00:00 2001 From: jazzcook15 <37391310+jazzcook15@users.noreply.github.com> Date: Tue, 28 Apr 2020 12:18:56 -0700 Subject: [PATCH] Fix #3954 - GPT2 is not traceable (#3955) * Update sqrt computation so it can survive a torch.jit.trace * Update modeling_gpt2.py Co-authored-by: Patrick von Platen --- src/transformers/modeling_gpt2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 120139964e3..35c3601ce8e 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -17,7 +17,6 @@ import logging -import math import os import torch @@ -143,7 +142,7 @@ class Attention(nn.Module): def _attn(self, q, k, v, attention_mask=None, head_mask=None): w = torch.matmul(q, k) if self.scale: - w = w / math.sqrt(v.size(-1)) + w = w / (v.size(-1) ** 0.5) nd, ns = w.size(-2), w.size(-1) mask = self.bias[:, :, ns - nd : ns, :ns] w = torch.where(mask, w, self.masked_bias)