Change attention mask dtype to be bool. Fix #1119

This commit is contained in:
Nikolay Korolev 2019-08-27 14:19:03 +03:00 committed by GitHub
parent e08c01aa1a
commit 53282b5bd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1142,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
+ torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
hids = []
attentions = []