add everything (#11651)
This commit is contained in:
parent 57b6a80de8
commit 6ee1a4fd3e
@@ -647,13 +647,13 @@ class BigBirdBlockSparseAttention(nn.Module):
             [
                 to_mask[:, :, :, : 3 * to_block_size],
                 to_mask[:, :, :, -to_block_size:],
-                first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+                to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
             ],
             dim=3,
         )
         second_rand_pad = torch.cat(
             [
-                first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+                rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
                 rand_mask[:, :, 0],
             ],
             dim=3,
@@ -781,13 +781,13 @@ class BigBirdBlockSparseAttention(nn.Module):
             [
                 to_mask[:, :, :, :to_block_size],
                 to_mask[:, :, :, -3 * to_block_size :],
-                context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+                to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
             ],
             dim=3,
         )
         second_last_rand_pad = torch.cat(
             [
-                context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+                rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
                 rand_mask[:, :, -1],
             ],
             dim=3,
@@ -475,13 +475,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
             [
                 to_mask[:, :, :, : 3 * to_block_size],
                 to_mask[:, :, :, -to_block_size:],
-                first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+                to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
             ],
             dim=3,
         )
         second_rand_pad = torch.cat(
             [
-                first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+                rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
                 rand_mask[:, :, 0],
             ],
             dim=3,
@@ -609,13 +609,13 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
             [
                 to_mask[:, :, :, :to_block_size],
                 to_mask[:, :, :, -3 * to_block_size :],
-                context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
+                to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
             ],
             dim=3,
         )
         second_last_rand_pad = torch.cat(
             [
-                context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
+                rand_mask.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]),
                 rand_mask[:, :, -1],
             ],
             dim=3,
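All four hunks make the same change: the all-ones padding blocks that get concatenated with slices of the attention masks are now created with new_ones on the mask tensors themselves (to_mask, rand_mask) rather than on the attention outputs (first_context_layer / context_layer). Because Tensor.new_ones inherits dtype and device from the tensor it is called on, this presumably keeps the padded masks in the masks' own dtype and on their device even when the context layers run in a different precision (e.g. fp16). A minimal sketch of that behaviour, with made-up shapes and a made-up seq_pad name:

import torch

# Illustrative sizes only; these are not the model's real dimensions.
bsz, n_heads, from_block_size, to_block_size, n_rand_blocks = 1, 2, 4, 4, 3

to_mask = torch.ones(bsz, 1, 1, 8 * to_block_size)                           # attention mask, float32
first_context_layer = torch.randn(bsz, n_heads, from_block_size, 16).half()  # e.g. half precision

# Tensor.new_ones inherits dtype and device from the tensor it is called on:
print(first_context_layer.new_ones(2).dtype)  # torch.float16
print(to_mask.new_ones(2).dtype)              # torch.float32

# Building the padding from the mask keeps every operand of the concatenation
# in the mask's dtype and on its device, which is the pattern the diff switches to:
seq_pad = torch.cat(
    [
        to_mask[:, :, :, : 3 * to_block_size],
        to_mask[:, :, :, -to_block_size:],
        to_mask.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]),
    ],
    dim=3,
)
print(seq_pad.dtype, seq_pad.shape)  # torch.float32 torch.Size([1, 1, 1, 28])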