mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add everything (#11651)
This commit is contained in:
parent
57b6a80de8
commit
6ee1a4fd3e
@ -647,13 +647,13 @@ class BigBirdBlockSparseAttention(nn.Module):
|
||||
[
|
||||
to_mask[:, :, :, : 3 * to_block_size],
|
||||
to_mask[:, :, :, -to_block_size:],
|
||||
first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
],
|
||||
dim=3,
|
||||
)
|
||||
second_rand_pad = torch.cat(
|
||||
[
|
||||
first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask[:, :, 0],
|
||||
],
|
||||
dim=3,
|
||||
@ -781,13 +781,13 @@ class BigBirdBlockSparseAttention(nn.Module):
|
||||
[
|
||||
to_mask[:, :, :, :to_block_size],
|
||||
to_mask[:, :, :, -3 * to_block_size :],
|
||||
context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
],
|
||||
dim=3,
|
||||
)
|
||||
second_last_rand_pad = torch.cat(
|
||||
[
|
||||
context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask[:, :, -1],
|
||||
],
|
||||
dim=3,
|
||||
|
@ -475,13 +475,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
|
||||
[
|
||||
to_mask[:, :, :, : 3 * to_block_size],
|
||||
to_mask[:, :, :, -to_block_size:],
|
||||
first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
],
|
||||
dim=3,
|
||||
)
|
||||
second_rand_pad = torch.cat(
|
||||
[
|
||||
first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask[:, :, 0],
|
||||
],
|
||||
dim=3,
|
||||
@ -609,13 +609,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
|
||||
[
|
||||
to_mask[:, :, :, :to_block_size],
|
||||
to_mask[:, :, :, -3 * to_block_size :],
|
||||
context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
|
||||
],
|
||||
dim=3,
|
||||
)
|
||||
second_last_rand_pad = torch.cat(
|
||||
[
|
||||
context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
|
||||
rand_mask[:, :, -1],
|
||||
],
|
||||
dim=3,
|
||||
|
Loading…
Reference in New Issue
Block a user