Merge pull request #1154 from ziliwang/master

fix: hard coding for max number
Thomas Wolf 2019-08-30 23:23:08 +02:00 committed by GitHub
commit 206c35e9a4


@@ -418,7 +418,10 @@ class XLNetRelativeAttention(nn.Module):
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            attn_score = attn_score - 1e30 * attn_mask
+            if attn_mask.dtype == torch.float16:
+                attn_score = attn_score - 65500 * attn_mask
+            else:
+                attn_score = attn_score - 1e30 * attn_mask
 
         # attention probability
         attn_prob = F.softmax(attn_score, dim=1)
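For context, a minimal standalone sketch (not part of the PR) of why the original constant breaks half-precision runs: 1e30 lies far above the float16 maximum (about 65504), so once the masking arithmetic happens in float16 it overflows to inf and non-finite values leak into the attention scores and the subsequent softmax, whereas 65500 is still representable and just gives masked positions a very large negative score.

import torch

# 0 = keep the position, 1 = mask it out (same convention as attn_mask above)
attn_mask = torch.tensor([0.0, 1.0], dtype=torch.float16)
attn_score = torch.zeros(2, dtype=torch.float16)

# 1e30 overflows to inf in float16, so non-finite entries can appear
bad = attn_score - 1e30 * attn_mask
# 65500 stays within the float16 range: masked positions get a large
# negative but finite score, unmasked positions are left unchanged
good = attn_score - 65500 * attn_mask

print(bad, torch.isfinite(bad))    # non-finite entries show up here
print(good, torch.isfinite(good))  # all finite, so the softmax stays well-behaved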