Mirror of https://github.com/huggingface/transformers.git
make fix copy (#11627)
parent dc3f6758cf
commit e7bff0aabe
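This copy-fix replaces the shape asserts in BigBirdPegasusDecoderAttention with explicit if/raise ValueError checks. A minimal sketch of that pattern, using a hypothetical standalone helper (check_attn_weights_shape is illustrative and not part of the commit):

import torch


def check_attn_weights_shape(attn_weights, bsz, num_heads, tgt_len, src_len):
    # Explicit check instead of `assert`: it still runs when Python is started
    # with -O (which strips asserts) and raises a clear ValueError to the caller.
    expected = (bsz * num_heads, tgt_len, src_len)
    if attn_weights.size() != expected:
        raise ValueError(
            f"Attention weights should be of size {expected}, but is {attn_weights.size()}"
        )


# Dummy tensor with the expected (bsz * num_heads, tgt_len, src_len) shape
attn_weights = torch.zeros(2 * 4, 5, 7)
check_attn_weights_shape(attn_weights, bsz=2, num_heads=4, tgt_len=5, src_len=7)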
@@ -1280,28 +1280,26 @@ class BigBirdPegasusDecoderAttention(nn.Module):
         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
 
-        assert attn_weights.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            src_len,
-        ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+            )
 
         if attention_mask is not None:
-            assert attention_mask.size() == (
-                bsz,
-                1,
-                tgt_len,
-                src_len,
-            ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
-            assert layer_head_mask.size() == (
-                self.num_heads,
-            ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+                )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
@@ -1319,17 +1317,14 @@ class BigBirdPegasusDecoderAttention(nn.Module):
 
         attn_output = torch.bmm(attn_probs, value_states)
 
-        assert attn_output.size() == (
-            bsz * self.num_heads,
-            tgt_len,
-            self.head_dim,
-        ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+            )
 
-        attn_output = (
-            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, embed_dim)
-        )
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
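The second hunk also unrolls the chained view/transpose/reshape into separate assignments; the two spellings produce the same tensor. A quick equivalence check with arbitrary dummy sizes (not taken from the model):

import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8
embed_dim = num_heads * head_dim
attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim)

# Chained form removed by the commit
chained = (
    attn_output.view(bsz, num_heads, tgt_len, head_dim)
    .transpose(1, 2)
    .reshape(bsz, tgt_len, embed_dim)
)

# Step-by-step form added by the commit
unrolled = attn_output.view(bsz, num_heads, tgt_len, head_dim)
unrolled = unrolled.transpose(1, 2)
unrolled = unrolled.reshape(bsz, tgt_len, embed_dim)

assert torch.equal(chained, unrolled)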