Cloned tensors after indexing in _compute_attn_output_with_global_indices (#13613)

Co-authored-by: Alessandro Suglia <asuglia@fb.com>
This commit is contained in:
Alessandro Suglia 2021-09-17 16:05:49 +01:00 committed by GitHub
parent ce32c69c0b
commit 19b7acdd61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -586,7 +586,7 @@ class LEDEncoderSelfAttention(nn.Module):
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
# compute attn output only global # compute attn output only global
attn_output_only_global = torch.matmul( attn_output_only_global = torch.matmul(
attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone()
).transpose(1, 2) ).transpose(1, 2)
# reshape attn probs # reshape attn probs

View File

@ -976,7 +976,7 @@ class LongformerSelfAttention(nn.Module):
# attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
# compute attn output only global # compute attn output only global
attn_output_only_global = torch.matmul( attn_output_only_global = torch.matmul(
attn_probs_only_global.transpose(1, 2), value_vectors_only_global.transpose(1, 2) attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone()
).transpose(1, 2) ).transpose(1, 2)
# reshape attn probs # reshape attn probs