mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
Change attention mask dtype to be bool. Fix #1119
This commit is contained in:
parent
e08c01aa1a
commit
53282b5bd0
@ -1142,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
mask_shift_len = qlen
|
mask_shift_len = qlen
|
||||||
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
|
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
|
||||||
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
|
+ torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
|
||||||
else:
|
else:
|
||||||
dec_attn_mask = torch.triu(
|
dec_attn_mask = torch.triu(
|
||||||
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
|
word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
|
||||||
|
|
||||||
hids = []
|
hids = []
|
||||||
attentions = []
|
attentions = []
|
||||||
|
Loading…
Reference in New Issue
Block a user