diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 3cfdee38cbe..c4ca0be8789 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -1142,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel): else: mask_shift_len = qlen dec_attn_mask = (torch.triu(all_ones, 1+mlen) - + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 + + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1 else: dec_attn_mask = torch.triu( - word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] + word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None] hids = [] attentions = []