Fixing this TransformerXL bool issue

thomwolf 2019-09-04 22:36:30 +02:00
parent 0b52642d37
commit 38b79b5a63

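This is a torch-version compatibility fix: Tensor.bool() only exists from PyTorch 1.2 on, so pre-casting the masks with .bool() broke on torch 1.0/1.1, while masked_fill_ still accepts the uint8/bool mask that (attn_mask == 1) produces on either version. A minimal sketch of the pattern the diff adopts, with made-up shapes in place of the model's real tensors:

    import torch

    # Illustrative stand-ins for the model's tensors (shapes are made up).
    attn_score = torch.randn(4, 4)
    attn_mask = torch.triu(torch.ones(4, 4), diagonal=1)  # float 0/1 mask

    # torch.sum(...).item() works on any dtype and torch version, unlike
    # .any() on a float tensor before torch 1.2.
    if attn_mask is not None and torch.sum(attn_mask).item():
        attn_mask = (attn_mask == 1)  # uint8 on torch < 1.2, bool on >= 1.2
        attn_score.masked_fill_(attn_mask, -float('inf'))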

@@ -423,7 +423,8 @@ class MultiHeadAttn(nn.Module):
         # [qlen x klen x bsz x n_head]
         attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
         attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1)  # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
             elif attn_mask.dim() == 3:
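Note the guard change as well: with the .bool() pre-cast gone (last hunk below), the incoming mask can be a plain float tensor, and Tensor.any() is, to my knowledge, only implemented for uint8 tensors before torch 1.2; torch.sum(attn_mask).item() gives the same nonzero check for a 0/1 mask on any numeric dtype.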
@@ -586,7 +587,8 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)

         #### compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1)  # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score = attn_score.float().masked_fill(
                     attn_mask[None,:,:,None], -1e30).type_as(attn_score)
@@ -680,7 +682,8 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)

         #### compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1)  # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -1139,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             else:
                 mask_shift_len = qlen
             dec_attn_mask = (torch.triu(all_ones, 1+mlen)
-                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
+                    + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
         else:
             dec_attn_mask = torch.triu(
-                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
+                word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]

         hids = []
         attentions = []
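Net effect: the decoder attention mask now stays a float tensor at construction time and each attention module converts it lazily. For the 0/1 masks built here, the on-the-fly (mask == 1) conversion selects the same positions the old .bool() cast did; a quick check (needs torch >= 1.2 for the .bool() side):

    import torch

    float_mask = torch.triu(torch.ones(4, 4), diagonal=1)
    # For a 0/1 mask the two conversions pick out the same positions.
    assert torch.equal(float_mask == 1, float_mask.bool())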