mirror of
https://github.com/huggingface/transformers.git
fix (#4839)
parent 37be3786cf
commit 4c7f564f9a
@@ -153,12 +153,11 @@ class LongformerSelfAttention(nn.Module):
         beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
         beginning_mask = beginning_mask_2d[None, :, None, :]
         ending_mask = beginning_mask.flip(dims=(1, 3))
-        seqlen = input_tensor.size(1)
         beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
-        beginning_mask = beginning_mask[:, :seqlen].expand(beginning_input.size())
+        beginning_mask = beginning_mask.expand(beginning_input.size())
         beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
         ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
-        ending_mask = ending_mask[:, -seqlen:].expand(ending_input.size())
+        ending_mask = ending_mask.expand(ending_input.size())
         ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
 
     def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
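Note: the hunk above touches _mask_invalid_locations, which blanks out sliding-window positions that would fall before the start or after the end of the sequence. The standalone sketch below (with made-up sizes w, seqlen, batch, heads on a zero-filled score tensor; it is not the library code) reproduces that masking pattern:

import torch

w = 3                          # one-sided window size (illustrative value)
seqlen, batch, heads = 8, 1, 1
scores = torch.zeros(batch, seqlen, heads, 2 * w + 1)

beginning_mask_2d = scores.new_ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]   # (1, w, 1, w + 1)
ending_mask = beginning_mask.flip(dims=(1, 3))          # mirrored for the sequence end

# mask the first w query rows in place (slices are views, so scores is modified)
beginning_input = scores[:, :w, :, : w + 1]
beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, -float("inf"))

# mask the last w query rows in place
ending_input = scores[:, -w:, :, -(w + 1):]
ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, -float("inf"))

print(scores[0, :, 0, :])   # -inf marks window slots that point outside the sequence

Printing scores[0, :, 0, :] shows -inf only in the upper-left and lower-right corners of the (seqlen, 2*w+1) band, i.e. the slots a sliding window would read from before token 0 or after the last token.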
@@ -301,7 +300,6 @@ class LongformerSelfAttention(nn.Module):
         k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
         # attn_weights = (batch_size, seqlen, num_heads, window*2+1)
         attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
         self._mask_invalid_locations(attn_weights, self.one_sided_attention_window_size)
         if remove_from_windowed_attention_mask is not None:
             # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
             # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
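For reference, the sliding-chunks matmul called above produces scores of shape (batch_size, seqlen, num_heads, 2*window+1), where entry [b, i, h, w + j] is the dot product of query i with key i + j. The naive, loop-based sketch below uses assumed toy sizes and is not the chunked implementation in _sliding_chunks_matmul_qk; it only illustrates that layout:

import torch

batch_size, seqlen, num_heads, head_dim, w = 2, 16, 4, 8, 3   # toy sizes

q = torch.randn(batch_size, seqlen, num_heads, head_dim)
k = torch.randn(batch_size, seqlen, num_heads, head_dim)

# attn_weights[b, i, h, w + offset] = <query i, key i + offset> for offset in [-w, w]
attn_weights = q.new_full((batch_size, seqlen, num_heads, 2 * w + 1), float("-inf"))
for offset in range(-w, w + 1):
    lo, hi = max(0, -offset), min(seqlen, seqlen - offset)
    attn_weights[:, lo:hi, :, w + offset] = torch.einsum(
        "bihd,bihd->bih", q[:, lo:hi], k[:, lo + offset : hi + offset]
    )

print(attn_weights.shape)   # torch.Size([2, 16, 4, 7]) == (batch_size, seqlen, num_heads, 2*w+1)

Offsets that run past the sequence boundary are left at -inf here, which is exactly what _mask_invalid_locations enforces on the real output before the softmax.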
@@ -329,7 +327,7 @@ class LongformerSelfAttention(nn.Module):
             selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
             # (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
             selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
-            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
+            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0
             # concat to attn_weights
             # (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
             attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
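The only functional change in this hunk spells the padding fill as a float literal (-10000.0). For context, the einsum "blhd,bshd->blhs" scores every query position against the keys gathered at the globally attended positions. The toy sketch below uses assumed shapes and names (not the module's real inputs) to show the resulting (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch) tensor and the large-negative fill applied to padded slots:

import torch

batch_size, seqlen, num_heads, head_dim = 2, 6, 2, 4
max_num_extra_indices_per_batch = 3   # padded count of global-attention positions per example

q = torch.randn(batch_size, seqlen, num_heads, head_dim)
selected_k = torch.randn(batch_size, max_num_extra_indices_per_batch, num_heads, head_dim)

# score of every query position l against every selected global key s
selected_attn_weights = torch.einsum("blhd,bshd->blhs", q, selected_k)
print(selected_attn_weights.shape)   # (2, 6, 2, 3)

# example: fill one padded slot of the first batch element with a large negative
# value so it contributes ~0 probability after softmax
selected_attn_weights[0, :, :, 2] = -10000.0

These global-key scores are then concatenated in front of the windowed scores along the last dimension, as in the torch.cat call above.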