diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py
index 0c5c5b77983..e3343222903 100644
--- a/pytorch_transformers/modeling_transfo_xl.py
+++ b/pytorch_transformers/modeling_transfo_xl.py
@@ -423,7 +423,8 @@ class MultiHeadAttn(nn.Module):
         # [qlen x klen x bsz x n_head]
         attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
         attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1) # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -586,7 +587,8 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)
 
         #### compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1) # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score = attn_score.float().masked_fill(
                     attn_mask[None,:,:,None], -1e30).type_as(attn_score)
@@ -680,7 +682,8 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)
 
         #### compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1) # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -1139,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             else:
                 mask_shift_len = qlen
             dec_attn_mask = (torch.triu(all_ones, 1+mlen)
-                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
+                    + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
         else:
             dec_attn_mask = torch.triu(
-                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
+                word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
 
         hids = []
         attentions = []
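
Note (not part of the patch): a minimal sketch, assuming a 0/1-valued attention mask like the decoder mask built in TransfoXLModel, of why the changed check and the "== 1" conversion reproduce the behaviour of the previous .any() / .bool() calls.

import torch

# Hypothetical 0/1-valued mask, analogous to the upper-triangular decoder mask.
attn_mask = torch.triu(torch.ones(4, 4), diagonal=1)

# Old check vs. new check: both are truthy exactly when the mask has any nonzero entry.
assert bool(attn_mask.any().item()) == bool(torch.sum(attn_mask).item())

# "== 1" produces a boolean mask suitable for masked_fill_ without calling
# .bool(), which is not available on older PyTorch releases.
bool_mask = (attn_mask == 1)
attn_score = torch.zeros(4, 4)
attn_score.masked_fill_(bool_mask, -float('inf'))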