mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Merge branch 'pooler_end_logits_fp16_fix' of https://github.com/slayton58/pytorch-transformers into pr/1284
This commit is contained in:
commit
c50783e388
@ -501,6 +501,9 @@ class PoolerEndLogits(nn.Module):
|
|||||||
x = self.dense_1(x).squeeze(-1)
|
x = self.dense_1(x).squeeze(-1)
|
||||||
|
|
||||||
if p_mask is not None:
|
if p_mask is not None:
|
||||||
|
if next(self.parameters()).dtype == torch.float16:
|
||||||
|
x = x * (1 - p_mask) - 65500 * p_mask
|
||||||
|
else:
|
||||||
x = x * (1 - p_mask) - 1e30 * p_mask
|
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user