Merge branch 'pooler_end_logits_fp16_fix' of https://github.com/slayton58/pytorch-transformers into pr/1284

This commit is contained in:
thomwolf 2019-10-01 18:17:48 -04:00
commit c50783e388

View File

@ -501,6 +501,9 @@ class PoolerEndLogits(nn.Module):
x = self.dense_1(x).squeeze(-1)
if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
return x