From 4c7f564f9a8c69430c55dce1b1f93c9e65d5944d Mon Sep 17 00:00:00 2001
From: ZhuBaohe
Date: Tue, 9 Jun 2020 00:28:50 +0800
Subject: [PATCH] fix (#4839)

---
 src/transformers/modeling_longformer.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py
index d61036c05c9..0b4e838873a 100644
--- a/src/transformers/modeling_longformer.py
+++ b/src/transformers/modeling_longformer.py
@@ -153,12 +153,11 @@ class LongformerSelfAttention(nn.Module):
         beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
         beginning_mask = beginning_mask_2d[None, :, None, :]
         ending_mask = beginning_mask.flip(dims=(1, 3))
-        seqlen = input_tensor.size(1)
         beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
-        beginning_mask = beginning_mask[:, :seqlen].expand(beginning_input.size())
+        beginning_mask = beginning_mask.expand(beginning_input.size())
         beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
         ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
-        ending_mask = ending_mask[:, -seqlen:].expand(ending_input.size())
+        ending_mask = ending_mask.expand(ending_input.size())
         ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
 
     def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
@@ -301,7 +300,6 @@ class LongformerSelfAttention(nn.Module):
         k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
         # attn_weights = (batch_size, seqlen, num_heads, window*2+1)
         attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
-        self._mask_invalid_locations(attn_weights, self.one_sided_attention_window_size)
         if remove_from_windowed_attention_mask is not None:
             # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
             # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
@@ -329,7 +327,7 @@ class LongformerSelfAttention(nn.Module):
             selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
             # (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
             selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
-            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
+            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0
             # concat to attn_weights
             # (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
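
Below is a minimal, standalone sketch (not part of the patch) of the boundary masking that `_mask_invalid_locations` performs, illustrating why the `[:, :seqlen]` / `[:, -seqlen:]` slicing removed in the first hunk is a no-op: the beginning/ending masks already have exactly `w` rows, and the sliding-window setup uses sequences longer than `w`, so slicing to `seqlen` keeps the whole mask. All dimensions and values here are toy assumptions for illustration only.

import torch

# Toy dimensions (assumptions, not the library's defaults).
batch_size, seqlen, num_heads, w = 2, 16, 4, 3

# Dummy windowed attention scores: (batch_size, seqlen, num_heads, 2 * w + 1).
attn_scores = torch.zeros(batch_size, seqlen, num_heads, 2 * w + 1)

# Lower-triangular mask flipped upside down: row i flags the window slots that
# would look before position 0 (mirrored at the end for slots past seqlen - 1).
beginning_mask_2d = attn_scores.new_ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]  # (1, w, 1, w + 1)
ending_mask = beginning_mask.flip(dims=(1, 3))

# Only the first and last w positions contain out-of-range window slots.
affected_seqlen = w

# The masks already cover exactly `affected_seqlen` rows, so they can be
# expanded directly onto the sliced scores, as the patch does.
beginning_input = attn_scores[:, :affected_seqlen, :, : w + 1]
beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, -float("inf"))

ending_input = attn_scores[:, -affected_seqlen:, :, -(w + 1) :]
ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, -float("inf"))

# The first query can only attend to itself and its w right-hand neighbours,
# so its w left-hand window slots are now -inf.
print(attn_scores[0, 0, 0])  # tensor([-inf, -inf, -inf, 0., 0., 0., 0.])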