Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 10:41:07 +06:00)
Fixing this TransformerXL bool issue
This commit is contained in:
parent 0b52642d37
commit 38b79b5a63
@@ -423,7 +423,8 @@ class MultiHeadAttn(nn.Module):
         # [qlen x klen x bsz x n_head]
         attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
         attn_score.mul_(self.scale)
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1) # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
             elif attn_mask.dim() == 3:
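For context, a minimal standalone sketch of what the added conversion does (not the model code; the shapes and the causal mask below are made up). masked_fill_ expects a boolean mask in recent PyTorch releases (byte masks are deprecated), and comparing the numeric mask against 1 yields one without calling .bool():

import torch

qlen, klen, bsz, n_head = 4, 6, 2, 3
attn_score = torch.randn(qlen, klen, bsz, n_head)  # [qlen x klen x bsz x n_head], as in the comment above

# Hypothetical causal mask of shape [qlen, klen]: 1 marks positions that may not be attended to.
mlen = klen - qlen
attn_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen)

if attn_mask is not None and torch.sum(attn_mask).item():
    attn_mask = attn_mask == 1  # switch to bool, mirroring the patched branch
    attn_score.masked_fill_(attn_mask[:, :, None, None], -float('inf'))

print(attn_score[0, -1, 0, 0])  # a masked position is now -inf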
@@ -586,7 +587,8 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)
 
         #### compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1) # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score = attn_score.float().masked_fill(
                     attn_mask[None,:,:,None], -1e30).type_as(attn_score)
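Unlike the in-place fill in the previous hunk, the fill here is out-of-place: scores are cast to float32, masked entries are set to a large negative constant, and the result is cast back to the original score dtype. A rough sketch of that pattern with made-up shapes (not the model code):

import torch

attn_score = torch.randn(4, 6, 2, 3).half()          # pretend fp16 attention scores
attn_mask = torch.zeros(4, 6, dtype=torch.long)
attn_mask[:, -2:] = 1                                 # hypothetical padded positions

attn_mask = attn_mask == 1  # switch to bool, as in the patch
attn_score = attn_score.float().masked_fill(
    attn_mask[:, :, None, None], -1e30).type_as(attn_score)

print(attn_score.dtype, attn_score[0, -1, 0, 0])      # float16; -1e30 overflows to -inf after the cast back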
@@ -680,7 +682,8 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         attn_score.mul_(self.scale)
 
         #### compute attention probability
-        if attn_mask is not None and attn_mask.any().item():
+        if attn_mask is not None and torch.sum(attn_mask).item():
+            attn_mask = (attn_mask == 1) # Switch to bool
             if attn_mask.dim() == 2:
                 attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
             elif attn_mask.dim() == 3:
@@ -1139,10 +1142,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             else:
                 mask_shift_len = qlen
             dec_attn_mask = (torch.triu(all_ones, 1+mlen)
-                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
+                    + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
         else:
             dec_attn_mask = torch.triu(
-                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]
+                word_emb.new_ones(qlen, klen), diagonal=1+mlen)[:,:,None]
 
         hids = []
         attentions = []
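As a hedged illustration (toy qlen/mlen values; word_emb stands in for the embedded input), this is roughly what the same_length mask looks like once .bool() is dropped: it keeps the dtype produced by new_ones()/triu(), so it also works on PyTorch versions that predate Tensor.bool(), and the `attn_mask == 1` conversion added in the attention modules above turns it into a bool mask right before masked_fill:

import torch

qlen, mlen = 4, 2
klen = qlen + mlen
word_emb = torch.zeros(qlen, 1, 8)            # stand-in for the embedded input

all_ones = word_emb.new_ones(qlen, klen)
mask_shift_len = qlen                         # same_length branch without extra shifting
dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
        + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # shape [qlen, klen, 1]

print(dec_attn_mask.dtype, dec_attn_mask.shape)  # float32 here, not bool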