add everything (#11651)

This commit is contained in:
Vasudev Gupta 2021-05-13 16:21:30 +05:30 committed by GitHub
parent 57b6a80de8
commit 6ee1a4fd3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 8 deletions

View File

@@ -647,13 +647,13 @@ class BigBirdBlockSparseAttention(nn.Module):
[
to_mask[:, :, :, : 3 * to_block_size],
to_mask[:, :, :, -to_block_size:],
- first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+ to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
],
dim=3,
)
second_rand_pad = torch.cat(
[
- first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+ rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, 0],
],
dim=3,
@@ -781,13 +781,13 @@ class BigBirdBlockSparseAttention(nn.Module):
[
to_mask[:, :, :, :to_block_size],
to_mask[:, :, :, -3 * to_block_size :],
- context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+ to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
],
dim=3,
)
second_last_rand_pad = torch.cat(
[
- context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+ rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, -1],
],
dim=3,

View File

@@ -475,13 +475,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
[
to_mask[:, :, :, : 3 * to_block_size],
to_mask[:, :, :, -to_block_size:],
- first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+ to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
],
dim=3,
)
second_rand_pad = torch.cat(
[
- first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+ rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, 0],
],
dim=3,
@@ -609,13 +609,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
[
to_mask[:, :, :, :to_block_size],
to_mask[:, :, :, -3 * to_block_size :],
- context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+ to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
],
dim=3,
)
second_last_rand_pad = torch.cat(
[
- context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+ rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
rand_mask[:, :, -1],
],
dim=3,