From 53282b5bd0cf78fae913d1d7e43f94c94620df0c Mon Sep 17 00:00:00 2001 From: Nikolay Korolev Date: Tue, 27 Aug 2019 14:19:03 +0300 Subject: [PATCH] Change attention mask dtype to be bool. Fix #1119 --- pytorch_transformers/modeling_transfo_xl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 3cfdee38cbe..c4ca0be8789 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -1142,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel): else: mask_shift_len = qlen dec_attn_mask = (torch.triu(all_ones, 1+mlen) - + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 + + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1 else: dec_attn_mask = torch.triu( - word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] + word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None] hids = [] attentions = []