Mirror of https://github.com/huggingface/transformers.git
Merge pull request #1154 from ziliwang/master
fix: hard coding for max number
Commit: 206c35e9a4
@@ -418,6 +418,9 @@ class XLNetRelativeAttention(nn.Module):
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            attn_score = attn_score - 1e30 * attn_mask
+            if attn_mask.dtype == torch.float16:
+                attn_score = attn_score - 65500 * attn_mask
+            else:
+                attn_score = attn_score - 1e30 * attn_mask
 
         # attention probability
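The hard-coded -1e30 masking constant is the bug the commit title refers to: float16's largest finite value is 65504, so under fp16 training the constant overflows and poisons the attention scores, and the patch swaps in 65500 whenever the mask is half precision. A minimal sketch of the idea (the toy tensors and shapes are assumptions for illustration; only the dtype check and the constants come from the patch):

import torch
import torch.nn.functional as F

# float16's largest finite value is 65504, so 1e30 overflows when cast down:
print(torch.tensor(1e30).half())  # tensor(inf, dtype=torch.float16)

# Toy scores and mask (1 = masked position); names mirror the patch, values are made up.
attn_score = torch.zeros(1, 4, dtype=torch.float16)
attn_mask = torch.tensor([[0., 0., 1., 1.]], dtype=torch.float16)

# The patched masking: pick a constant the mask's dtype can actually represent.
if attn_mask.dtype == torch.float16:
    attn_score = attn_score - 65500 * attn_mask
else:
    attn_score = attn_score - 1e30 * attn_mask

# Scores stay finite (about -65504 at masked positions), so the softmax is
# well defined; the cast to float32 here is only for a readable printout.
print(F.softmax(attn_score.float(), dim=-1))  # ~[[0.5, 0.5, 0.0, 0.0]]

The choice of 65500 sits just under the float16 maximum, so the subtraction stays finite while still pushing masked logits low enough that exp() underflows to zero and masked positions receive essentially no attention probability.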