ZhuBaohe 2020-06-09 00:28:50 +08:00 committed by GitHub
parent 37be3786cf
commit 4c7f564f9a


@@ -153,12 +153,11 @@ class LongformerSelfAttention(nn.Module):
         beginning_mask_2d = input_tensor.new_ones(w, w + 1).tril().flip(dims=[0])
         beginning_mask = beginning_mask_2d[None, :, None, :]
         ending_mask = beginning_mask.flip(dims=(1, 3))
-        seqlen = input_tensor.size(1)
         beginning_input = input_tensor[:, :affected_seqlen, :, : w + 1]
-        beginning_mask = beginning_mask[:, :seqlen].expand(beginning_input.size())
+        beginning_mask = beginning_mask.expand(beginning_input.size())
         beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
         ending_input = input_tensor[:, -affected_seqlen:, :, -(w + 1) :]
-        ending_mask = ending_mask[:, -seqlen:].expand(ending_input.size())
+        ending_mask = ending_mask.expand(ending_input.size())
         ending_input.masked_fill_(ending_mask == 1, -float("inf"))  # `== 1` converts to bool or uint8
 
     def _sliding_chunks_matmul_qk(self, q: torch.Tensor, k: torch.Tensor, w: int):
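Note on the hunk above: `beginning_mask` has a fixed size of `w` along dimension 1 (and `ending_mask` likewise), and Longformer pads its input so that `seqlen` is never shorter than a window, so the dropped `[:, :seqlen]` / `[:, -seqlen:]` slices were no-ops and the local `seqlen` variable became unused. A minimal standalone sketch of the same masking pattern (not the library code; batch=1, heads=1, w=3, seqlen=8 are made-up sizes for illustration):

import torch

w = 3
# (batch, seqlen, heads, 2*w+1) sliding-window attention scores, all zeros here
scores = torch.zeros(1, 8, 1, 2 * w + 1)

beginning_mask_2d = scores.new_ones(w, w + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]   # (1, w, 1, w+1)
ending_mask = beginning_mask.flip(dims=(1, 3))         # mirror image for the sequence end

beginning_input = scores[:, :w, :, : w + 1]            # first w query positions
beginning_input.masked_fill_(beginning_mask.expand(beginning_input.size()) == 1, -float("inf"))
ending_input = scores[:, -w:, :, -(w + 1):]            # last w query positions
ending_input.masked_fill_(ending_mask.expand(ending_input.size()) == 1, -float("inf"))

print(scores[0, :, 0, :])  # -inf marks window slots that fall outside the sequence

Because `scores[...]` uses basic slicing, `masked_fill_` modifies the underlying tensor in place, which is the same trick the library relies on.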
@@ -301,7 +300,6 @@ class LongformerSelfAttention(nn.Module):
         k = k.view(seqlen, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
         # attn_weights = (batch_size, seqlen, num_heads, window*2+1)
         attn_weights = self._sliding_chunks_matmul_qk(q, k, self.one_sided_attention_window_size)
-        self._mask_invalid_locations(attn_weights, self.one_sided_attention_window_size)
         if remove_from_windowed_attention_mask is not None:
             # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
             # from (batch_size x seqlen) to (batch_size x seqlen x num_heads x hidden_size)
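For the shape comment in this hunk: `_sliding_chunks_matmul_qk` produces one score per query position for each relative key offset in [-w, w]. The chunked implementation is what the library uses for memory efficiency; a naive quadratic reference that yields the same (batch_size, seqlen, num_heads, 2*w+1) layout (purely illustrative, the function name is made up) could look like:

import torch
import torch.nn.functional as F

def naive_banded_qk(q, k, w):
    # q, k: (batch, seqlen, heads, head_dim) -> (batch, seqlen, heads, 2*w+1)
    scores = torch.einsum("bxhd,byhd->bhxy", q, k)        # full (batch, heads, seqlen, seqlen)
    batch, heads, seqlen, _ = scores.shape
    padded = F.pad(scores, (w, w), value=float("-inf"))   # pad the key axis by w on each side
    # band slot t for query i corresponds to key i + t - w, i.e. relative offset t - w
    idx = torch.arange(seqlen).unsqueeze(-1) + torch.arange(2 * w + 1)
    band = padded.gather(-1, idx.expand(batch, heads, seqlen, 2 * w + 1))
    return band.permute(0, 2, 1, 3)

q = torch.randn(2, 8, 4, 16)
k = torch.randn(2, 8, 4, 16)
print(naive_banded_qk(q, k, w=3).shape)  # torch.Size([2, 8, 4, 7])

In this sketch the -inf entries at the sequence edges are exactly the "invalid locations" that `_mask_invalid_locations` fills in the library code; where that call lives after the removal above is not part of the hunks shown here.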
@@ -329,7 +327,7 @@ class LongformerSelfAttention(nn.Module):
             selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
             # (batch_size, seqlen, num_heads, max_num_extra_indices_per_batch)
             selected_attn_weights = torch.einsum("blhd,bshd->blhs", (q, selected_k))
-            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
+            selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000.0
             # concat to attn_weights
             # (batch_size, seqlen, num_heads, extra attention count + 2*window+1)
             attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
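On the last hunk: changing `-10000` to `-10000.0` only makes the fill value an explicit float literal; `selected_attn_weights` is already a floating-point tensor, so the large negative bias (softmax weight ≈ 0) written into padded global-attention slots is unchanged. A small standalone sketch of this step, with made-up sizes (batch=2, seqlen=8, heads=4, head_dim=16, 3 global slots):

import torch

batch, seqlen, heads, dim, max_extra = 2, 8, 4, 16, 3
q = torch.randn(batch, seqlen, heads, dim)
selected_k = torch.randn(batch, max_extra, heads, dim)   # keys gathered at global-attention positions

# one score per (query position, global key): (batch, seqlen, heads, max_extra)
selected_attn_weights = torch.einsum("blhd,bshd->blhs", q, selected_k)

# pretend the second example has only 2 real global tokens, so slot 2 is padding
padding_zeros = (torch.tensor([1]), torch.tensor([2]))   # (batch indices, padded slot indices)
selected_attn_weights[padding_zeros[0], :, :, padding_zeros[1]] = -10000.0

# windowed scores (2*w+1 slots, w=3 here) are concatenated after the global ones
windowed = torch.randn(batch, seqlen, heads, 2 * 3 + 1)
attn_weights = torch.cat((selected_attn_weights, windowed), dim=-1)
print(attn_weights.shape)  # torch.Size([2, 8, 4, 10])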