diff --git a/docs/source/model_doc/longformer.rst b/docs/source/model_doc/longformer.rst index c2d44a60b07..ea03b798cef 100644 --- a/docs/source/model_doc/longformer.rst +++ b/docs/source/model_doc/longformer.rst @@ -102,3 +102,25 @@ LongformerForQuestionAnswering .. autoclass:: transformers.LongformerForQuestionAnswering :members: + + +TFLongformerModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFLongformerModel + :members: + + +TFLongformerForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFLongformerForMaskedLM + :members: + + +TFLongformerForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFLongformerForQuestionAnswering + :members: + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c43e19604f2..5df76248fe9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -399,6 +399,7 @@ if is_torch_available(): LongformerForMultipleChoice, LongformerForTokenClassification, LongformerForQuestionAnswering, + LongformerSelfAttention, LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ) @@ -568,6 +569,14 @@ if is_tf_available(): TFGPT2PreTrainedModel, ) + from .modeling_tf_longformer import ( + TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFLongformerModel, + TFLongformerForMaskedLM, + TFLongformerForQuestionAnswering, + TFLongformerSelfAttention, + ) + from .modeling_tf_mobilebert import ( TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFMobileBertModel, diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index da422b95f84..033b8bbbcc0 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -71,7 +71,6 @@ def _get_question_end_index(input_ids, sep_token_id): assert ( sep_token_indices.shape[0] == 3 * batch_size ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error." - return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] @@ -81,7 +80,6 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru before `sep_token_id` if `before_sep_token is True` else after `sep_token_id`. """ - question_end_index = _get_question_end_index(input_ids, sep_token_id) question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 # bool attention mask with True in locations of global attention @@ -131,6 +129,172 @@ class LongformerSelfAttention(nn.Module): self.one_sided_attn_window_size = attention_window // 2 + def forward( + self, hidden_states, attention_mask=None, output_attentions=False, + ): + """ + LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. + Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer. + + The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to + -ve: no attention + 0: local attention + +ve: global attention + + """ + attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) + + # is index masked or global attention + is_index_masked = attention_mask < 0 + is_index_global_attn = attention_mask > 0 + is_global_attn = is_index_global_attn.flatten().any().item() + + hidden_states = hidden_states.transpose(0, 1) + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = hidden_states.size() + assert ( + embed_dim == self.embed_dim + ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" + + # normalize query + query_vectors /= math.sqrt(self.head_dim) + + query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None] + + # cast to fp32/fp16 then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( + remove_from_windowed_attention_mask, -10000.0 + ) + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size + ) + + # pad local attention probs + attn_scores += diagonal_mask + + assert list(attn_scores.size()) == [ + batch_size, + seq_len, + self.num_heads, + self.one_sided_attn_window_size * 2 + 1, + ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" + + # compute local attention probs from global attention keys and contact over window dim + if is_global_attn: + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + # calculate global attn probs from global key + + global_key_attn_scores = self._concat_with_global_key_attn_probs( + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ) + # concat to attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) + + # free memory + del global_key_attn_scores + + attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_probs = attn_probs_fp32.type_as(attn_scores) + + # free memory + del attn_probs_fp32 + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0) + + # apply dropout + attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training) + + value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) + + # compute local attention output with global attention value and add + if is_global_attn: + # compute sum of global and local attn + attn_output = self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ) + else: + # compute local attn only + attn_output = self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ) + + assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" + attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() + + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + if is_global_attn: + global_attn_output = self._compute_global_attn_output_from_hidden( + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + ) + + # get only non zero global attn output + nonzero_global_attn_output = global_attn_output[ + is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] + ] + + # overwrite values with global attention + attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( + len(is_local_index_global_attn_nonzero[0]), -1 + ) + + attn_output = attn_output.transpose(0, 1) + + if output_attentions: + if is_global_attn: + # With global attention, return global attention probabilities only + # batch_size x num_heads x max_num_global_attention_tokens x sequence_length + # which is the attention weights from tokens with global attention to all tokens + # It doesn't not return local attention + # In case of variable number of global attantion in the rows of a batch, + # attn_probs are padded with -10000.0 attention scores + attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + else: + # without global attention, return local attention probabilities + # batch_size x num_heads x sequence_length x window_size + # which is the attention weights of every token attending to its neighbours + attn_probs = attn_probs.permute(0, 2, 1, 3) + + outputs = (attn_output, attn_probs) if output_attentions else (attn_output,) + return outputs + @staticmethod def _pad_and_transpose_last_two_dims(hidden_states_padded, padding): """pads rows and then flips rows and columns""" @@ -143,8 +307,20 @@ class LongformerSelfAttention(nn.Module): return hidden_states_padded @staticmethod - def _pad_by_window_overlap_except_last_row(chunked_hidden_states): - """shift every row 1 step right, converting columns into diagonals""" + def _pad_and_diagonalize(chunked_hidden_states): + """shift every row 1 step right, converting columns into diagonals. + Example: + chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492, + -1.8348, 0.7672, 0.2986, 0.0285, + -0.7584, 0.4206, -0.0405, 0.1599, + 2.0514, -1.1600, 0.5372, 0.2629 ] + window_overlap = num_rows = 4 + (pad & diagonilize) => + [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 + 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 + 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 + 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] + """ total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() chunked_hidden_states = F.pad( chunked_hidden_states, (0, window_overlap + 1) @@ -181,7 +357,8 @@ class LongformerSelfAttention(nn.Module): chunk_stride[1] = chunk_stride[1] // 2 return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor: + @staticmethod + def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor: beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0]) beginning_mask = beginning_mask_2d[None, :, None, :] ending_mask = beginning_mask.flip(dims=(1, 3)) @@ -243,6 +420,7 @@ class LongformerSelfAttention(nn.Module): diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[ :, :, -(window_overlap + 1) : -1, window_overlap + 1 : ] + diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[ :, 0, : window_overlap - 1, 1 - window_overlap : ] @@ -261,11 +439,13 @@ class LongformerSelfAttention(nn.Module): """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the same shape as `attn_probs`""" batch_size, seq_len, num_heads, head_dim = value.size() + assert seq_len % (window_overlap * 2) == 0 assert attn_probs.size()[:3] == value.size()[:3] assert attn_probs.size(3) == 2 * window_overlap + 1 chunks_count = seq_len // window_overlap - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + chunked_attn_probs = attn_probs.transpose(1, 2).reshape( batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 ) @@ -287,178 +467,11 @@ class LongformerSelfAttention(nn.Module): ) chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride) - chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs) + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value)) return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2) - def forward( - self, hidden_states, attention_mask=None, output_attentions=False, - ): - """ - LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. - Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer. - - The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to - -ve: no attention - 0: local attention - +ve: global attention - - """ - - attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) - - # is index masked or global attention - is_index_masked = attention_mask < 0 - is_index_global_attn = attention_mask > 0 - is_global_attn = is_index_global_attn.flatten().any().item() - - hidden_states = hidden_states.transpose(0, 1) - - # project hidden states - query_vectors = self.query(hidden_states) - key_vectors = self.key(hidden_states) - value_vectors = self.value(hidden_states) - - seq_len, batch_size, embed_dim = hidden_states.size() - assert ( - embed_dim == self.embed_dim - ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}" - - # normalize query - query_vectors /= math.sqrt(self.head_dim) - - query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - - # attn_probs = (batch_size, seq_len, num_heads, window*2+1) - attn_scores = self._sliding_chunks_query_key_matmul( - query_vectors, key_vectors, self.one_sided_attn_window_size - ) - - # values to pad for attention probs - remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1) - - # cast to fp32/fp16 then replace 1's with -inf - float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill( - remove_from_windowed_attention_mask, -10000.0 - ) - # diagonal mask with zeros everywhere and -inf inplace of padding - diagonal_mask = self._sliding_chunks_query_key_matmul( - float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size - ) - - # pad local attention probs - attn_scores += diagonal_mask - - assert list(attn_scores.size()) == [ - batch_size, - seq_len, - self.num_heads, - self.one_sided_attn_window_size * 2 + 1, - ], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}" - - # compute local attention probs from global attention keys and contact over window dim - if is_global_attn: - # compute global attn indices required through out forward fn - ( - max_num_global_attn_indices, - is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero, - ) = self._get_global_attn_indices(is_index_global_attn) - # calculate global attn probs from global key - global_key_attn_scores = self._concat_with_global_key_attn_probs( - query_vectors=query_vectors, - key_vectors=key_vectors, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - ) - # concat to attn_probs - # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) - attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1) - - # free memory - del global_key_attn_scores - - attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability - attn_probs = attn_probs_fp32.type_as(attn_scores) - - # free memory - del attn_probs_fp32 - - # softmax sometimes inserts NaN if all positions are masked, replace them with 0 - attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0) - - # apply dropout - attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training) - - value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1) - - # compute local attention output with global attention value and add - if is_global_attn: - # compute sum of global and local attn - attn_output = self._compute_attn_output_with_global_indices( - value_vectors=value_vectors, - attn_probs=attn_probs, - max_num_global_attn_indices=max_num_global_attn_indices, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - ) - else: - # compute local attn only - attn_output = self._sliding_chunks_matmul_attn_probs_value( - attn_probs, value_vectors, self.one_sided_attn_window_size - ) - - assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size" - attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous() - - # compute value for global attention and overwrite to attention output - # TODO: remove the redundant computation - if is_global_attn: - global_attn_output = self._compute_global_attn_output_from_hidden( - hidden_states=hidden_states, - max_num_global_attn_indices=max_num_global_attn_indices, - is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, - is_index_global_attn_nonzero=is_index_global_attn_nonzero, - is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, - is_index_masked=is_index_masked, - ) - - # get only non zero global attn output - nonzero_global_attn_output = global_attn_output[ - is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1] - ] - # overwrite values with global attention - attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view( - len(is_local_index_global_attn_nonzero[0]), -1 - ) - - attn_output = attn_output.transpose(0, 1) - - if output_attentions: - if is_global_attn: - # With global attention, return global attention probabilities only - # batch_size x num_heads x sequence_length x window_size - # which is the attention weights from all tokens to all tokens for global attention - # It doesn't not return local attention. Only tokens with global attention have values > 0.0 - attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices] - # pad attn_probs to max length with 0.0 since global attn did not attend there - window_size = self.one_sided_attn_window_size * 2 + 1 - attn_probs = F.pad(attn_probs, (0, window_size - max_num_global_attn_indices), value=0.0,) - attn_probs = attn_probs.permute(0, 2, 1, 3) - else: - # without global attention, return local attention probabilities - # batch_size x num_heads x sequence_length x window_size - # which is the attention weights of every token attending to its neighbours - attn_probs = attn_probs.permute(0, 2, 1, 3) - - outputs = (attn_output, attn_probs) if output_attentions else (attn_output,) - return outputs - @staticmethod def _get_global_attn_indices(is_index_global_attn): """ compute global attn indices required throughout forward pass """ @@ -503,12 +516,16 @@ class LongformerSelfAttention(nn.Module): key_vectors_only_global = key_vectors.new_zeros( batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim ) + key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero] + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + attn_probs_from_global_key[ is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] ] = -10000.0 + return attn_probs_from_global_key def _compute_attn_output_with_global_indices( @@ -600,7 +617,7 @@ class LongformerSelfAttention(nn.Module): is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : ] = -10000.0 - global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,) + global_attn_scores = global_attn_scores.masked_fill(is_index_masked[:, None, None, :], -10000.0,) global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) @@ -754,7 +771,6 @@ class LongformerPreTrainedModel(PreTrainedModel): LONGFORMER_START_DOCSTRING = r""" - This model is a PyTorch `torch.nn.Module `__ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. @@ -823,13 +839,13 @@ LONGFORMER_INPUTS_DOCSTRING = r""" ) class LongformerModel(LongformerPreTrainedModel): """ - This class overrides :class:`~transformers.RobertaModel` to provide the ability to process - long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer - `__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer selfattention + This class copied code from :class:`~transformers.RobertaModel` and overwrote standard self-attention with longformer self-attention to provide the ability to process + long sequences following the self-attention approach described in `Longformer: the Long-Document Transformer + `__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer self-attention combines a local (sliding window) and global attention to extend to long documents without the O(n^2) increase in memory and compute. - The selfattention module `LongformerSelfAttention` implemented here supports the combination of local and + The self-attention module `LongformerSelfAttention` implemented here supports the combination of local and global attention but it lacks support for autoregressive attention and dilated attention. Autoregressive and dilated attention are more relevant for autoregressive language modeling than finetuning on downstream tasks. Future release will add support for autoregressive attention, but the support for dilated attention @@ -883,7 +899,7 @@ class LongformerModel(LongformerPreTrainedModel): inputs_embeds: torch.Tensor, pad_token_id: int, ): - """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + """A helper function to pad tokens and mask to work with implementation of Longformer self-attention.""" # padding attention_window = ( self.config.attention_window @@ -1053,6 +1069,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel): self.init_weights() + def get_output_embeddings(self): + return self.lm_head.decoder + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1315,11 +1334,14 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # set global attention on question tokens if global_attention_mask is None: - logger.info("Initializing global attention on question tokens...") - # put global attention on all tokens until `config.sep_token_id` is reached - global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id) + if input_ids is None: + logger.warning( + "It is not possible to automatically generate the `global_attention_mask` because input_ids is None. Please make sure that it is correctly set." + ) + else: + # set global attention on question tokens automatically + global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id) outputs = self.longformer( input_ids, @@ -1504,7 +1526,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # set global attention on question tokens - if global_attention_mask is None: + if global_attention_mask is None and input_ids is not None: logger.info("Initializing global attention on multiple choice...") # put global attention on all tokens after `config.sep_token_id` global_attention_mask = torch.stack( diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index 66176a0e316..fa0f6961fe3 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -29,6 +29,7 @@ from .configuration_auto import ( ElectraConfig, FlaubertConfig, GPT2Config, + LongformerConfig, MobileBertConfig, OpenAIGPTConfig, RobertaConfig, @@ -93,6 +94,7 @@ from .modeling_tf_flaubert import ( TFFlaubertWithLMHeadModel, ) from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model +from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel from .modeling_tf_mobilebert import ( TFMobileBertForMaskedLM, TFMobileBertForMultipleChoice, @@ -149,6 +151,7 @@ TF_MODEL_MAPPING = OrderedDict( (AlbertConfig, TFAlbertModel), (CamembertConfig, TFCamembertModel), (XLMRobertaConfig, TFXLMRobertaModel), + (LongformerConfig, TFLongformerModel), (RobertaConfig, TFRobertaModel), (BertConfig, TFBertModel), (OpenAIGPTConfig, TFOpenAIGPTModel), @@ -191,6 +194,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( (AlbertConfig, TFAlbertForMaskedLM), (CamembertConfig, TFCamembertForMaskedLM), (XLMRobertaConfig, TFXLMRobertaForMaskedLM), + (LongformerConfig, TFLongformerForMaskedLM), (RobertaConfig, TFRobertaForMaskedLM), (BertConfig, TFBertForMaskedLM), (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel), @@ -226,6 +230,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( (AlbertConfig, TFAlbertForMaskedLM), (CamembertConfig, TFCamembertForMaskedLM), (XLMRobertaConfig, TFXLMRobertaForMaskedLM), + (LongformerConfig, TFLongformerForMaskedLM), (RobertaConfig, TFRobertaForMaskedLM), (BertConfig, TFBertForMaskedLM), (MobileBertConfig, TFMobileBertForMaskedLM), @@ -259,6 +264,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( (AlbertConfig, TFAlbertForQuestionAnswering), (CamembertConfig, TFCamembertForQuestionAnswering), (XLMRobertaConfig, TFXLMRobertaForQuestionAnswering), + (LongformerConfig, TFLongformerForQuestionAnswering), (RobertaConfig, TFRobertaForQuestionAnswering), (BertConfig, TFBertForQuestionAnswering), (XLNetConfig, TFXLNetForQuestionAnsweringSimple), diff --git a/src/transformers/modeling_tf_longformer.py b/src/transformers/modeling_tf_longformer.py new file mode 100644 index 00000000000..3d3691c2661 --- /dev/null +++ b/src/transformers/modeling_tf_longformer.py @@ -0,0 +1,1392 @@ +# coding=utf-8 +# Copyright 2020 The Allen Institute for AI team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tensorflow Longformer model. """ + +import logging + +import tensorflow as tf + +from .configuration_longformer import LongformerConfig +from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_tf_bert import TFBertIntermediate, TFBertOutput, TFBertPooler, TFBertSelfOutput +from .modeling_tf_outputs import TFBaseModelOutputWithPooling, TFMaskedLMOutput, TFQuestionAnsweringModelOutput +from .modeling_tf_roberta import TFRobertaEmbeddings, TFRobertaLMHead +from .modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + cast_bool_to_primitive, + get_initializer, + keras_serializable, + shape_list, +) +from .tokenization_utils import BatchEncoding + + +logger = logging.getLogger(__name__) + +_CONFIG_FOR_DOC = "LongformerConfig" +_TOKENIZER_FOR_DOC = "LongformerTokenizer" + +TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "allenai/longformer-base-4096", + "allenai/longformer-large-4096", + "allenai/longformer-large-4096-finetuned-triviaqa", + "allenai/longformer-base-4096-extra.pos.embd.only", + "allenai/longformer-large-4096-extra.pos.embd.only", + # See all Longformer models at https://huggingface.co/models?filter=longformer +] + + +def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True): + """ + Computes global attention mask by putting attention on all tokens + before `sep_token_id` if `before_sep_token is True` else after + `sep_token_id`. + """ + + assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" + question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1] + question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1 + # bool attention mask with True in locations of global attention + attention_mask = tf.range(input_ids_shape[1]) + if before_sep_token is True: + attention_mask = tf.cast( + tf.broadcast_to(attention_mask, input_ids_shape) < tf.broadcast_to(question_end_index, input_ids_shape), + tf.dtypes.int32, + ) + else: + # last token is separation token and should not be counted and in the middle are two separation tokens + attention_mask = tf.cast( + tf.broadcast_to(attention_mask, input_ids_shape) + > tf.broadcast_to(question_end_index + 1, input_ids_shape), + tf.dtypes.int32, + ) * tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32) + + return attention_mask + + +class TFLongformerSelfAttention(tf.keras.layers.Layer): + def __init__(self, config, layer_id, **kwargs): + super().__init__(**kwargs) + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_heads = config.num_attention_heads + self.head_dim = int(config.hidden_size / config.num_attention_heads) + self.embed_dim = config.hidden_size + + self.query = tf.keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = tf.keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = tf.keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + + # separate projection layers for tokens with global attention + self.query_global = tf.keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="query_global" + ) + self.key_global = tf.keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="key_global" + ) + self.value_global = tf.keras.layers.Dense( + self.embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="value_global" + ) + + self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) + self.global_dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) + + self.layer_id = layer_id + + attention_window = config.attention_window[self.layer_id] + assert ( + attention_window % 2 == 0 + ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}" + assert ( + attention_window > 0 + ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}" + + self.one_sided_attn_window_size = attention_window // 2 + + def call( + self, inputs, training=False, + ): + """ + LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. + Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer. + + The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to + -ve: no attention + 0: local attention + +ve: global attention + + """ + # retrieve input args + hidden_states, attention_mask, output_attentions = inputs + + attention_mask = tf.squeeze(tf.squeeze(attention_mask, axis=2), axis=1) + # is index masked or global attention + + is_index_masked = tf.math.less(attention_mask, 0) + is_index_global_attn = tf.math.greater(attention_mask, 0) + is_global_attn = tf.math.reduce_any(is_index_global_attn) + + hidden_states = tf.transpose(hidden_states, (1, 0, 2)) + + # project hidden states + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) + + seq_len, batch_size, embed_dim = shape_list(hidden_states) + tf.debugging.assert_equal( + embed_dim, + self.embed_dim, + message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}", + ) + + # normalize query + query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) + + query_vectors = tf.transpose( + tf.reshape(query_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) + ) + key_vectors = tf.transpose( + tf.reshape(key_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) + ) + + # attn_probs = (batch_size, seq_len, num_heads, window*2+1) + attn_scores = self._sliding_chunks_query_key_matmul( + query_vectors, key_vectors, self.one_sided_attn_window_size + ) + + # values to pad for attention probs + float_mask = tf.cast((attention_mask != 0)[:, :, None, None], dtype=tf.float32) * -10000.0 + + # diagonal mask with zeros everywhere and -inf inplace of padding + diagonal_mask = self._sliding_chunks_query_key_matmul( + tf.ones(shape_list(float_mask), dtype=tf.float32), float_mask, self.one_sided_attn_window_size + ) + + # pad local attention probs + attn_scores += diagonal_mask + + tf.debugging.assert_equal( + shape_list(attn_scores), + [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1], + message=f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}", + ) + + # compute global attn indices required through out forward fn + ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) = self._get_global_attn_indices(is_index_global_attn) + + # this function is only relevant for global attention + + attn_scores = tf.cond( + is_global_attn, + lambda: self._concat_with_global_key_attn_probs( + attn_scores=attn_scores, + query_vectors=query_vectors, + key_vectors=key_vectors, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + ), + lambda: attn_scores, + ) + + attn_probs = tf.nn.softmax(attn_scores, axis=-1) + + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_probs = tf.where( + tf.broadcast_to(is_index_masked[:, :, None, None], shape_list(attn_probs)), 0.0, attn_probs + ) + + # apply dropout + attn_probs = self.dropout(attn_probs, training=training) + + value_vectors = tf.transpose( + tf.reshape(value_vectors, (seq_len, batch_size, self.num_heads, self.head_dim)), (1, 0, 2, 3) + ) + + # if global attention, compute sum of global and local attn + attn_output = tf.cond( + is_global_attn, + lambda: self._compute_attn_output_with_global_indices( + value_vectors=value_vectors, + attn_probs=attn_probs, + max_num_global_attn_indices=max_num_global_attn_indices, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + ), + lambda: self._sliding_chunks_matmul_attn_probs_value( + attn_probs, value_vectors, self.one_sided_attn_window_size + ), + ) + + tf.debugging.assert_equal( + shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size" + ) + attn_output = tf.reshape(tf.transpose(attn_output, (1, 0, 2, 3)), (seq_len, batch_size, embed_dim)) + + # compute value for global attention and overwrite to attention output + # TODO: remove the redundant computation + attn_output = tf.cond( + is_global_attn, + lambda: self._compute_global_attn_output_from_hidden( + attn_output=attn_output, + hidden_states=hidden_states, + max_num_global_attn_indices=max_num_global_attn_indices, + is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero=is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, + is_index_masked=is_index_masked, + training=training, + ), + lambda: attn_output, + ) + + attn_output = tf.transpose(attn_output, (1, 0, 2)) + + # GLOBAL ATTN: + # With global attention, return global attention probabilities only + # batch_size x num_heads x max_num_global_attention_tokens x sequence_length + # which is the attention weights from tokens with global attention to all tokens + # It doesn't not return local attention + # In case of variable number of global attantion in the rows of a batch, + # attn_probs are padded with -10000.0 attention scores + # LOCAL ATTN: + # without global attention, return local attention probabilities + # batch_size x num_heads x sequence_length x window_size + # which is the attention weights of every token attending to its neighbours + attn_probs = tf.cond( + is_global_attn, + lambda: self._get_global_attn_probs(attn_probs, max_num_global_attn_indices), + lambda: tf.transpose(attn_probs, (0, 2, 1, 3)), + ) + + outputs = (attn_output, attn_probs) + return outputs + + @staticmethod + def _get_global_attn_probs(attn_probs, max_num_global_attn_indices): + # pad attn_probs to max length with 0.0 since global attn did not attend there + attn_probs = tf.concat( + [ + attn_probs[:, :, :, :max_num_global_attn_indices], + tf.zeros_like(attn_probs)[:, :, :, max_num_global_attn_indices:], + ], + axis=-1, + ) + attn_probs = tf.transpose(attn_probs, (0, 2, 1, 3)) + return attn_probs + + def _sliding_chunks_query_key_matmul(self, query, key, window_overlap): + """Matrix multiplication of query and key tensors using with a sliding window attention pattern. + This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) + with an overlap of size window_overlap""" + batch_size, seq_len, num_heads, head_dim = shape_list(query) + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), + 0, + message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}", + ) + tf.debugging.assert_equal( + shape_list(query), + shape_list(key), + message=f"Shape of query and key should be equal, but got query: {shape_list(query)} and key: {shape_list(key)}", + ) + + chunks_count = seq_len // window_overlap - 1 + + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 + query = tf.reshape(tf.transpose(query, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + + chunked_query = self._chunk(query, window_overlap) + chunked_key = self._chunk(key, window_overlap) + + # matrix multipication + # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim + # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap + chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply + + # convert diagonals into columns + paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings) + + # allocate space for the overall attention matrix where the chunks are combined. The last dimension + # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to + # window_overlap previous words). The following column is attention score from each word to itself, then + # followed by window_overlap columns for the upper triangle. + + # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions + # - copying the main diagonal and the upper triangle + # TODO: This code is most likely not very efficient and should be improved + diagonal_attn_scores_up_triang = tf.concat( + [ + diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1], + diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1], + ], + axis=1, + ) + + # - copying the lower triangle + diagonal_attn_scores_low_triang = tf.concat( + [ + tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), + diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :], + ], + axis=1, + ) + diagonal_attn_scores_first_chunk = tf.concat( + [ + tf.roll(diagonal_chunked_attention_scores, shift=[1, window_overlap], axis=[2, 3])[ + :, :, :window_overlap, :window_overlap + ], + tf.zeros((batch_size * num_heads, 1, window_overlap, window_overlap)), + ], + axis=1, + ) + + first_chunk_mask = ( + tf.broadcast_to( + tf.range(chunks_count + 1)[None, :, None, None], + shape=(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap), + ) + < 1 + ) + + diagonal_attn_scores_low_triang = tf.where( + first_chunk_mask, diagonal_attn_scores_first_chunk, diagonal_attn_scores_low_triang + ) + + # merging upper and lower triangle + diagonal_attention_scores = tf.concat( + [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1 + ) + + # separate batch_size and num_heads dimensions again + diagonal_attention_scores = tf.transpose( + tf.reshape(diagonal_attention_scores, (batch_size, num_heads, seq_len, 2 * window_overlap + 1)), + (0, 2, 1, 3), + ) + + diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap) + return diagonal_attention_scores + + @staticmethod + def _mask_invalid_locations(input_tensor, window_overlap): + # create correct upper triangle bool mask + mask_2d_upper = tf.reverse( + tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0), axis=[0] + ) + # pad to full matrix + padding = tf.constant( + [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]] + ) + + # create lower mask + mask_2d = tf.pad(mask_2d_upper, padding) + # combine with upper mask + mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1]) + + # broadcast to full matrix + mask_4d = tf.broadcast_to(mask_2d[None, :, None, :], shape_list(input_tensor)) + + # inf tensor used for masking + inf_tensor = -float("inf") * tf.ones_like(input_tensor, dtype=tf.dtypes.float32) + + # mask + input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor) + + return input_tensor + + def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap): + + """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. + Returned tensor will be of the same shape as `attn_probs`""" + + batch_size, seq_len, num_heads, head_dim = shape_list(value) + + tf.debugging.assert_equal( + seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap" + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[:3], + shape_list(value)[:3], + message="value and attn_probs must have same dims (except head_dim)", + ) + tf.debugging.assert_equal( + shape_list(attn_probs)[3], + 2 * window_overlap + 1, + message="attn_probs last dim has to be 2 * window_overlap + 1", + ) + + chunks_count = seq_len // window_overlap - 1 + # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap + + chunked_attn_probs = tf.reshape( + tf.transpose(attn_probs, (0, 2, 1, 3)), + (batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1), + ) + + # group batch_size and num_heads dimensions into one + value = tf.reshape(tf.transpose(value, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim)) + + # pad seq_len with w at the beginning of the sequence and another window overlap at the end + + paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32) + padded_value = tf.pad(value, paddings, constant_values=-1) + + # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap + + frame_size = 3 * window_overlap * head_dim + frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count + + chunked_value = tf.signal.frame( + tf.reshape(padded_value, (batch_size * num_heads, -1)), frame_size, frame_hop_size + ) + chunked_value = tf.reshape( + chunked_value, (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim) + ) + + tf.debugging.assert_equal( + shape_list(chunked_value), + [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim], + message="Chunked value has the wrong shape", + ) + + chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs) + + context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value) + context = tf.transpose(tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)), (0, 2, 1, 3)) + return context + + @staticmethod + def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings): + """pads rows and then flips rows and columns""" + hidden_states_padded = tf.pad( + hidden_states_padded, paddings + ) # padding value is not important because it will be overwritten + if len(shape_list(hidden_states_padded)) > 3: + batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length)) + else: + batch_size, seq_length, hidden_dim = shape_list(hidden_states_padded) + hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, hidden_dim, seq_length)) + return hidden_states_padded + + @staticmethod + def _pad_and_diagonalize(chunked_hidden_states): + """shift every row 1 step right, converting columns into diagonals. + Example: + chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492, + -1.8348, 0.7672, 0.2986, 0.0285, + -0.7584, 0.4206, -0.0405, 0.1599, + 2.0514, -1.1600, 0.5372, 0.2629 ] + window_overlap = num_rows = 4 + (pad & diagonilize) => + [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 + 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 + 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 + 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] + """ + total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states) + + paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]]) + chunked_hidden_states = tf.pad( + chunked_hidden_states, paddings + ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten + + chunked_hidden_states = tf.reshape( + chunked_hidden_states, (total_num_heads, num_chunks, -1) + ) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap + chunked_hidden_states = chunked_hidden_states[ + :, :, :-window_overlap + ] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap + chunked_hidden_states = tf.reshape( + chunked_hidden_states, (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim) + ) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap + chunked_hidden_states = chunked_hidden_states[:, :, :, :-1] + return chunked_hidden_states + + @staticmethod + def _chunk(hidden_states, window_overlap): + """convert into overlapping chunkings. Chunk size = 2w, overlap size = w""" + batch_size, seq_length, hidden_dim = shape_list(hidden_states) + num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1 + + # define frame size and frame stride (similar to convolution) + frame_hop_size = window_overlap * hidden_dim + frame_size = 2 * frame_hop_size + + hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim)) + + # chunk with overlap + chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size) + + tf.debugging.assert_equal( + shape_list(chunked_hidden_states), + [batch_size, num_output_chunks, frame_size], + message=f"Make sure chunking is correctly applied. `Chunked hidden states should have output dimension {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}.", + ) + + chunked_hidden_states = tf.reshape( + chunked_hidden_states, (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim) + ) + + return chunked_hidden_states + + @staticmethod + def _get_global_attn_indices(is_index_global_attn): + """ compute global attn indices required throughout forward pass """ + # helper variable + num_global_attn_indices = tf.reduce_sum(tf.cast(is_index_global_attn, dtype=tf.dtypes.int32), axis=1) + + # max number of global attn indices in batch + max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices) + + # indices of global attn + is_index_global_attn_nonzero = tf.where(is_index_global_attn) + + # helper variable + is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims( + num_global_attn_indices, axis=-1 + ) + + # location of the non-padding values within global attention indices + is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn) + + # location of the padding values within global attention indices + is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn)) + + return ( + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ) + + def _concat_with_global_key_attn_probs( + self, + attn_scores, + key_vectors, + query_vectors, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + ): + batch_size = shape_list(key_vectors)[0] + + # select global key vectors + global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero) + # create only global key vectors + key_vectors_only_global = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_key_vectors, + shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim), + ) + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global) + # (batch_size, max_num_global_attn_indices, seq_len, num_heads) + attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2)) + mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( + shape_list(attn_probs_from_global_key_trans)[-2:] + ) + mask = tf.ones(mask_shape) * -10000.0 + + # scatter mask + attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update( + attn_probs_from_global_key_trans, is_local_index_no_global_attn_nonzero, mask + ) + + # (batch_size, seq_len, num_heads, max_num_global_attn_indices) + attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1)) + + # concat to attn_probs + # (batch_size, seq_len, num_heads, extra attention count + 2*window+1) + attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1) + + return attn_scores + + def _compute_attn_output_with_global_indices( + self, + value_vectors, + attn_probs, + max_num_global_attn_indices, + is_index_global_attn_nonzero, + is_local_index_global_attn_nonzero, + ): + batch_size = shape_list(attn_probs)[0] + + # cut local attn probs to global only + attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices] + + # select global value vectors + global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero) + # create only global value vectors + value_vectors_only_global = tf.scatter_nd( + is_local_index_global_attn_nonzero, + global_value_vectors, + shape=(batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim), + ) + + # compute attn output only global + attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global) + + # reshape attn probs + attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:] + + # compute attn output with global + attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value( + attn_probs_without_global, value_vectors, self.one_sided_attn_window_size + ) + return attn_output_only_global + attn_output_without_global + + def _compute_global_attn_output_from_hidden( + self, + attn_output, + hidden_states, + max_num_global_attn_indices, + is_local_index_global_attn_nonzero, + is_index_global_attn_nonzero, + is_local_index_no_global_attn_nonzero, + is_index_masked, + training, + ): + seq_len, batch_size = shape_list(hidden_states)[:2] + + # prepare global hidden states + global_attn_hidden_states = tf.gather_nd(hidden_states, tf.reverse(is_index_global_attn_nonzero, axis=[1])) + global_attn_hidden_states = tf.scatter_nd( + tf.reverse(is_local_index_global_attn_nonzero, axis=[1]), + global_attn_hidden_states, + shape=(max_num_global_attn_indices, batch_size, self.embed_dim), + ) + + # global key, query, value + global_query_vectors_only_global = self.query_global(global_attn_hidden_states) + global_key_vectors = self.key_global(hidden_states) + global_value_vectors = self.value_global(hidden_states) + + # normalize + global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32)) + + # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim) + global_query_vectors_only_global = tf.transpose( + tf.reshape( + global_query_vectors_only_global, + (max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim), + ), + (1, 0, 2), + ) + + # (..., batch_size * self.num_heads, seq_len, head_dim) + global_key_vectors = tf.transpose( + tf.reshape(global_key_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2) + ) + + # (..., batch_size * self.num_heads, seq_len, head_dim) + global_value_vectors = tf.transpose( + tf.reshape(global_value_vectors, (-1, batch_size * self.num_heads, self.head_dim)), (1, 0, 2) + ) + + # compute attn scores + global_attn_scores = tf.matmul(global_query_vectors_only_global, tf.transpose(global_key_vectors, (0, 2, 1))) + + tf.debugging.assert_equal( + shape_list(global_attn_scores), + [batch_size * self.num_heads, max_num_global_attn_indices, seq_len], + message=f"global_attn_scores have the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is {shape_list(global_attn_scores)}.", + ) + + global_attn_scores = tf.reshape( + global_attn_scores, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + ) + + global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3)) + mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple( + shape_list(global_attn_scores_trans)[-2:] + ) + global_attn_mask = tf.ones(mask_shape) * -10000.0 + + # scatter mask + global_attn_scores_trans = tf.tensor_scatter_nd_update( + global_attn_scores_trans, is_local_index_no_global_attn_nonzero, global_attn_mask + ) + global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3)) + + # mask global attn scores + attn_mask = tf.broadcast_to(is_index_masked[:, None, None, :], shape_list(global_attn_scores)) + global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores) + + global_attn_scores = tf.reshape( + global_attn_scores, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len) + ) + + # compute global attn probs + global_attn_probs_float = tf.nn.softmax(global_attn_scores, axis=-1) + + # dropout + global_attn_probs = self.global_dropout(global_attn_probs_float, training=training) + + # global attn output + global_attn_output = tf.matmul(global_attn_probs, global_value_vectors) + + tf.debugging.assert_equal( + shape_list(global_attn_output), + [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim], + message=f"global_attn_output tensor has the wrong size. Size should be {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is {shape_list(global_attn_output)}.", + ) + + global_attn_output = tf.reshape( + global_attn_output, (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim) + ) + + # get only non zero global attn output + nonzero_global_attn_output = tf.gather_nd( + tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero + ) + nonzero_global_attn_output = tf.reshape( + nonzero_global_attn_output, (shape_list(is_local_index_global_attn_nonzero)[0], -1), + ) + + # overwrite values with global attention + attn_output = tf.tensor_scatter_nd_update( + attn_output, tf.reverse(is_index_global_attn_nonzero, axis=[1]), nonzero_global_attn_output + ) + + return attn_output + + +class TFLongformerAttention(tf.keras.layers.Layer): + def __init__(self, config, layer_id=0, **kwargs): + super().__init__(**kwargs) + self.self_attention = TFLongformerSelfAttention(config, layer_id, name="self") + self.dense_output = TFBertSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, inputs, training=False): + input_tensor, attention_mask, output_attentions = inputs + + self_outputs = self.self_attention([input_tensor, attention_mask, output_attentions], training=training) + attention_output = self.dense_output(self_outputs[0], input_tensor, training=training) + + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class TFLongformerLayer(tf.keras.layers.Layer): + def __init__(self, config, layer_id=0, **kwargs): + super().__init__(**kwargs) + self.attention = TFLongformerAttention(config, layer_id, name="attention") + self.intermediate = TFBertIntermediate(config, name="intermediate") + self.longformer_output = TFBertOutput(config, name="output") + + def call(self, inputs, training=False): + hidden_states, attention_mask, output_attentions = inputs + + attention_outputs = self.attention([hidden_states, attention_mask, output_attentions], training=training) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.longformer_output(intermediate_output, attention_output, training=training) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + + +class TFLongformerEncoder(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.layer = [ + TFLongformerLayer(config, i, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers) + ] + + def call(self, inputs, training=False): + hidden_states, attention_mask, output_attentions, output_hidden_states, padding_len = inputs + + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True: + hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + all_hidden_states = all_hidden_states + (hidden_states_to_add,) + + layer_outputs = layer_module([hidden_states, attention_mask, output_attentions], training=training) + hidden_states = layer_outputs[0] + + if cast_bool_to_primitive(output_attentions, self.output_attentions) is True: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True: + hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states + all_hidden_states = all_hidden_states + (hidden_states_to_add,) + + outputs = (hidden_states,) + if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True: + outputs = outputs + (all_hidden_states,) + if cast_bool_to_primitive(output_attentions, self.output_attentions) is True: + outputs = outputs + (all_attentions,) + return outputs # outputs, (hidden states), (attentions) + + +@keras_serializable +class TFLongformerMainLayer(tf.keras.layers.Layer): + config_class = LongformerConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + if isinstance(config.attention_window, int): + assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value" + assert config.attention_window > 0, "`config.attention_window` has to be positive" + config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer + else: + assert len(config.attention_window) == config.num_hidden_layers, ( + "`len(config.attention_window)` should equal `config.num_hidden_layers`. " + f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" + ) + + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.pad_token_id = config.pad_token_id + self.attention_window = config.attention_window + + self.embeddings = TFRobertaEmbeddings(config, name="embeddings") + self.encoder = TFLongformerEncoder(config, name="encoder") + self.pooler = TFBertPooler(config, name="pooler") + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + self.embeddings.vocab_size = value.shape[0] + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + raise NotImplementedError + + def call( + self, + inputs, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + if isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else attention_mask + global_attention_mask = inputs[2] if len(inputs) > 2 else attention_mask + token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids + position_ids = inputs[4] if len(inputs) > 4 else position_ids + inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds + output_attentions = inputs[6] if len(inputs) > 6 else output_attentions + output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states + return_dict = inputs[8] if len(inputs) > 8 else return_dict + assert len(inputs) <= 9, "Too many inputs." + elif isinstance(inputs, (dict, BatchEncoding)): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask", attention_mask) + global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) + token_type_ids = inputs.get("token_type_ids", token_type_ids) + position_ids = inputs.get("position_ids", position_ids) + inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) + output_attentions = inputs.get("output_attentions", output_attentions) + output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) + return_dict = inputs.get("return_dict", return_dict) + assert len(inputs) <= 9, "Too many inputs." + else: + input_ids = inputs + + output_attentions = output_attentions if output_attentions is not None else self.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states + return_dict = return_dict if return_dict is not None else self.return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) + if token_type_ids is None: + token_type_ids = tf.fill(input_shape, 0) + + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask) + + padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + pad_token_id=self.pad_token_id, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + + extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) + encoder_outputs = self.encoder( + [embedding_output, extended_attention_mask, output_attentions, output_hidden_states, padding_len], + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + + # undo padding + if padding_len > 0: + # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1) + sequence_output = sequence_output[:, :-padding_len] + + if not return_dict: + return (sequence_output, pooled_output,) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _pad_to_window_size( + self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, pad_token_id, + ): + """A helper function to pad tokens and mask to work with implementation of Longformer selfattention.""" + # padding + attention_window = ( + self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window) + ) + + assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}" + input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds) + batch_size, seq_len = input_shape[:2] + + padding_len = (attention_window - seq_len % attention_window) % attention_window + if padding_len > 0: + logger.info( + "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format( + seq_len, seq_len + padding_len, attention_window + ) + ) + paddings = tf.constant([[0, 0], [0, padding_len]]) + if input_ids is not None: + input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id) + if position_ids is not None: + # pad with position_id = pad_token_id as in modeling_roberta.RobertaEmbeddings + position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id) + if inputs_embeds is not None: + input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id) + inputs_embeds_padding = self.embeddings(input_ids_padding) + inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2) + + attention_mask = tf.pad( + attention_mask, paddings, constant_values=False + ) # no attention on the padding tokens + token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0 + + return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds + + @staticmethod + def _merge_to_attention_mask(attention_mask: tf.Tensor, global_attention_mask: tf.Tensor): + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + return attention_mask + + +class TFLongformerPreTrainedModel(TFPreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = LongformerConfig + base_model_prefix = "longformer" + + @property + def dummy_inputs(self): + input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]) + # make sure global layers are initialized + attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]) + global_attention_mask = tf.constant([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "global_attention_mask": global_attention_mask, + } + + +LONGFORMER_START_DOCSTRING = r""" + This model is a `tf.keras.Model `__ sub-class. + Use it as a regular TF 2.0 Keras Model and + refer to the TF 2.0 documentation for all matter related to general usage and behavior. + + .. note:: + + TF 2.0 models accepts two formats as inputs: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + + This second option is useful when using :obj:`tf.keras.Model.fit()` method which currently requires having + all the tensors in the first argument of the model call function: :obj:`model(inputs)`. + + If you choose this second option, there are three possibilities you can use to gather all the input Tensors + in the first positional argument : + + - a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})` + + Parameters: + config (:class:`~transformers.LongformerConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + + +LONGFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`tf.Tensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.LonmgformerTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.__call__` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + global_attention_mask (:obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Mask to decide the attention given on each token, local attention or global attenion. + Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for + task-specific finetuning because it makes the model more flexible at representing the task. For example, + for classification, the token should be given global attention. For QA, all question tokens should also have + global attention. Please refer to the `Longformer paper `__ for more details. + Mask values selected in ``[0, 1]``: + ``0`` for local attention (a sliding window attention), + ``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + + token_type_ids (:obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. + return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a + plain tuple. +""" + + +@add_start_docstrings( + "The bare Longformer Model outputting raw hidden-states without any specific head on top.", + LONGFORMER_START_DOCSTRING, +) +class TFLongformerModel(TFLongformerPreTrainedModel): + """ + This class copies code from :class:`~transformers.RobertaModel` and overwrites standard self-attention with longformer self-attention to provide the ability to process + long sequences following the self-attention approach described in `Longformer: the Long-Document Transformer + `__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer self-attention + combines a local (sliding window) and global attention to extend to long documents without the O(n^2) increase in + memory and compute. + + The self-attention module :obj:`LongformerSelfAttention` implemented here supports the combination of local and + global attention but it lacks support for autoregressive attention and dilated attention. Autoregressive + and dilated attention are more relevant for autoregressive language modeling than finetuning on downstream + tasks. Future release will add support for autoregressive attention, but the support for dilated attention + requires a custom CUDA kernel to be memory and compute efficient. + + """ + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.longformer = TFLongformerMainLayer(config, name="longformer") + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + def call(self, inputs, **kwargs): + outputs = self.longformer(inputs, **kwargs) + return outputs + + +@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING) +class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.longformer = TFLongformerMainLayer(config, name="longformer") + self.lm_head = TFRobertaLMHead(config, self.longformer.embeddings, name="lm_head") + + def get_output_embeddings(self): + return self.lm_head.decoder + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="allenai/longformer-base-4096", + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + inputs=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + training=False, + ): + r""" + labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.longformer.return_dict + if isinstance(inputs, (tuple, list)): + labels = inputs[9] if len(inputs) > 9 else labels + if len(inputs) > 9: + inputs = inputs[:9] + elif isinstance(inputs, (dict, BatchEncoding)): + labels = inputs.pop("labels", labels) + + outputs = self.longformer( + inputs, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output, training=training) + + loss = None if labels is None else self.compute_loss(labels, prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). """, + LONGFORMER_START_DOCSTRING, +) +class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.longformer = TFLongformerMainLayer(config, name="longformer") + self.qa_outputs = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="allenai/longformer-large-4096-finetuned-triviaqa", + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + inputs=None, + attention_mask=None, + global_attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + start_positions=None, + end_positions=None, + training=False, + ): + r""" + start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.longformer.return_dict + if isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + global_attention_mask = inputs[2] + start_positions = inputs[9] if len(inputs) > 9 else start_positions + end_positions = inputs[10] if len(inputs) > 10 else end_positions + if len(inputs) > 9: + inputs = inputs[:9] + elif isinstance(inputs, (dict, BatchEncoding)): + input_ids = inputs.get("input_ids", inputs) + global_attention_mask = inputs.get("global_attention_mask", global_attention_mask) + start_positions = inputs.pop("start_positions", start_positions) + end_positions = inputs.pop("end_positions", start_positions) + else: + input_ids = inputs + + # set global attention on question tokens + if global_attention_mask is None and input_ids is not None: + if input_ids is None: + logger.warning( + "It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set." + ) + elif tf.where(input_ids == self.config.sep_token_id).shape[0] != 3 * input_ids.shape[0]: + logger.warning( + f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error." + ) + else: + logger.info("Initializing global attention on question tokens...") + # put global attention on all tokens until `config.sep_token_id` is reached + sep_token_indices = tf.where(input_ids == self.config.sep_token_id) + global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices) + + outputs = self.longformer( + inputs, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 8b97ef3fc71..0730dce654f 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -33,6 +33,7 @@ if is_torch_available(): LongformerForTokenClassification, LongformerForQuestionAnswering, LongformerForMultipleChoice, + LongformerSelfAttention, ) @@ -325,7 +326,209 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs) +@require_torch class LongformerModelIntegrationTest(unittest.TestCase): + def _get_hidden_states(self): + return torch.tensor( + [ + [ + [ + 4.98332758e-01, + 2.69175139e00, + -7.08081422e-03, + 1.04915401e00, + -1.83476661e00, + 7.67220476e-01, + 2.98580543e-01, + 2.84803992e-02, + ], + [ + -7.58357372e-01, + 4.20635998e-01, + -4.04739919e-02, + 1.59924145e-01, + 2.05135748e00, + -1.15997978e00, + 5.37166397e-01, + 2.62873606e-01, + ], + [ + -1.69438001e00, + 4.17574660e-01, + -1.49196962e00, + -1.76483717e00, + -1.94566312e-01, + -1.71183858e00, + 7.72903565e-01, + -1.11557056e00, + ], + [ + 5.44028163e-01, + 2.05466114e-01, + -3.63045868e-01, + 2.41865062e-01, + 3.20348382e-01, + -9.05611176e-01, + -1.92690727e-01, + -1.19917547e00, + ], + ] + ], + dtype=torch.float32, + device=torch_device, + ) + + def test_diagonalize(self): + hidden_states = self._get_hidden_states() + hidden_states = hidden_states.reshape((1, 8, 4)) # set seq length = 8, hidden dim = 4 + chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2) + window_overlap_size = chunked_hidden_states.shape[2] + self.assertTrue(window_overlap_size == 4) + + padded_hidden_states = LongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states) + + self.assertTrue(padded_hidden_states.shape[-1] == chunked_hidden_states.shape[-1] + window_overlap_size - 1) + + # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000] + self.assertTrue(torch.allclose(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], atol=1e-3)) + self.assertTrue( + torch.allclose( + padded_hidden_states[0, 0, 0, 4:], + torch.zeros((3,), device=torch_device, dtype=torch.float32), + atol=1e-3, + ) + ) + # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629] + self.assertTrue(torch.allclose(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], atol=1e-3)) + self.assertTrue( + torch.allclose( + padded_hidden_states[0, 0, -1, :3], + torch.zeros((3,), device=torch_device, dtype=torch.float32), + atol=1e-3, + ) + ) + + def test_pad_and_transpose_last_two_dims(self): + hidden_states = self._get_hidden_states() + self.assertTrue(hidden_states.shape, (1, 8, 4)) + padding = (0, 0, 0, 1) + + padded_hidden_states = LongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, padding) + self.assertTrue(padded_hidden_states.shape, (1, 8, 5)) + + expected_added_dim = torch.zeros((5,), device=torch_device, dtype=torch.float32) + self.assertTrue(torch.allclose(expected_added_dim, padded_hidden_states[0, -1, :], atol=1e-6)) + self.assertTrue(torch.allclose(hidden_states[0, -1, :], padded_hidden_states.view(1, -1)[0, 24:32], atol=1e-6)) + + def test_chunk(self): + hidden_states = self._get_hidden_states() + batch_size = 1 + seq_length = 8 + hidden_size = 4 + hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size)) + + chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2) + + # expected slices across chunk and seq length dim + expected_slice_along_seq_length = torch.tensor( + [0.4983, -0.7584, -1.6944], device=torch_device, dtype=torch.float32 + ) + expected_slice_along_chunk = torch.tensor( + [0.4983, -1.8348, -0.7584, 2.0514], device=torch_device, dtype=torch.float32 + ) + + self.assertTrue(torch.allclose(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, atol=1e-3)) + self.assertTrue(torch.allclose(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, atol=1e-3)) + self.assertTrue(chunked_hidden_states.shape, (1, 3, 4, 4)) + + def test_mask_invalid_locations(self): + hidden_states = self._get_hidden_states() + + batch_size = 1 + seq_length = 8 + hidden_size = 4 + hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size)) + chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2) + + hid_states_1 = chunked_hidden_states.clone() + LongformerSelfAttention._mask_invalid_locations(hid_states_1, 1) + self.assertTrue(torch.isinf(hid_states_1).sum().item() == 8) + + hid_states_2 = chunked_hidden_states.clone() + LongformerSelfAttention._mask_invalid_locations(hid_states_2, 2) + self.assertTrue(torch.isinf(hid_states_2).sum().item() == 24) + + hid_states_3 = chunked_hidden_states.clone()[:, :, :, :3] + LongformerSelfAttention._mask_invalid_locations(hid_states_3, 2) + self.assertTrue(torch.isinf(hid_states_3).sum().item() == 24) + + hid_states_4 = chunked_hidden_states.clone()[:, :, 2:, :] + LongformerSelfAttention._mask_invalid_locations(hid_states_4, 2) + self.assertTrue(torch.isinf(hid_states_4).sum().item() == 12) + + def test_layer_local_attn(self): + model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny") + model.eval() + layer = model.encoder.layer[0].attention.self.to(torch_device) + hidden_states = self._get_hidden_states() + batch_size, seq_length, hidden_size = hidden_states.size() + attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device) + attention_mask[:, :, :, -2:] = -10000 + output_hidden_states = layer(hidden_states, attention_mask)[0] + + self.assertTrue(output_hidden_states.shape, (1, 4, 8)) + self.assertTrue( + torch.allclose( + output_hidden_states[0, 1], + torch.tensor( + [0.0019, 0.0122, -0.0171, -0.0256, -0.0300, 0.0173, -0.0115, 0.0048], + dtype=torch.float32, + device=torch_device, + ), + atol=1e-3, + ) + ) + + def test_layer_global_attn(self): + model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny") + model.eval() + layer = model.encoder.layer[0].attention.self.to(torch_device) + hidden_states = torch.cat([self._get_hidden_states(), self._get_hidden_states() - 0.5], dim=0) + batch_size, seq_length, hidden_size = hidden_states.size() + attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device) + + # create attn mask + attention_mask[0, :, :, -2:] = 10000.0 + attention_mask[0, :, :, -1:] = -10000.0 + attention_mask[1, :, :, 1:] = 10000.0 + output_hidden_states = layer(hidden_states, attention_mask)[0] + + self.assertTrue(output_hidden_states.shape, (2, 4, 8)) + + self.assertTrue( + torch.allclose( + output_hidden_states[0, 2], + torch.tensor( + [-0.0651, -0.0393, 0.0309, -0.0342, -0.0066, -0.0155, -0.0209, -0.0494], + dtype=torch.float32, + device=torch_device, + ), + atol=1e-3, + ) + ) + + self.assertTrue( + torch.allclose( + output_hidden_states[1, -2], + torch.tensor( + [-0.0405, -0.0384, 0.0396, -0.0374, -0.0341, 0.0136, 0.0014, -0.0571], + dtype=torch.float32, + device=torch_device, + ), + atol=1e-3, + ) + ) + @slow def test_inference_no_head(self): model = LongformerModel.from_pretrained("allenai/longformer-base-4096") @@ -371,13 +574,13 @@ class LongformerModelIntegrationTest(unittest.TestCase): input_ids = torch.tensor( [[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device ) # long input + input_ids = input_ids.to(torch_device) loss, prediction_scores = model(input_ids, labels=input_ids) expected_loss = torch.tensor(0.0074, device=torch_device) expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device) expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device) - input_ids = input_ids.to(torch_device) self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4)) self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4)) diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py new file mode 100644 index 00000000000..5b8e9ae26ca --- /dev/null +++ b/tests/test_modeling_tf_longformer.py @@ -0,0 +1,531 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import is_tf_available +from transformers.testing_utils import require_tf, slow + +from .test_configuration_common import ConfigTester +from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor + + +if is_tf_available(): + import tensorflow as tf + from transformers import ( + LongformerConfig, + TFLongformerModel, + TFLongformerForMaskedLM, + TFLongformerForQuestionAnswering, + TFLongformerSelfAttention, + ) + + def shape_list(x): + """ + copied from transformers.modeling_tf_utils + """ + static = x.shape.as_list() + dynamic = tf.shape(x) + return [dynamic[i] if s is None else s for i, s in enumerate(static)] + + +class TFLongformerModelTester: + def __init__( + self, parent, + ): + self.parent = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = True + self.use_input_mask = True + self.use_token_type_ids = True + self.use_labels = True + self.vocab_size = 99 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 512 + self.type_vocab_size = 16 + self.type_sequence_label_size = 2 + self.initializer_range = 0.02 + self.num_labels = 3 + self.num_choices = 4 + self.scope = None + self.attention_window = 4 + + # `ModelTesterMixin.test_attention_outputs` is expecting attention tensors to be of size + # [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention + # returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1] + # because its local attention only attends to `self.attention_window` and one before and one after + self.key_length = self.attention_window + 2 + + # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for + # the `test_attention_outputs` and `test_hidden_states_output` tests + self.encoder_seq_length = ( + self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window + ) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = LongformerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + attention_window=self.attention_window, + ) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def create_and_check_attention_mask_determinism( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFLongformerModel(config=config) + + attention_mask = tf.ones(input_ids.shape, dtype=tf.dtypes.int32) + output_with_mask = model(input_ids, attention_mask=attention_mask)[0] + output_without_mask = model(input_ids)[0] + tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4) + + def create_and_check_longformer_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFLongformerModel(config=config) + sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids) + sequence_output, pooled_output = model(input_ids) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size]) + + def create_and_check_longformer_model_with_global_attention_mask( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFLongformerModel(config=config) + half_input_mask_length = shape_list(input_mask)[-1] // 2 + global_attention_mask = tf.concat( + [ + tf.zeros_like(input_mask)[:, :half_input_mask_length], + tf.ones_like(input_mask)[:, half_input_mask_length:], + ], + axis=-1, + ) + + sequence_output, pooled_output = model( + input_ids, + attention_mask=input_mask, + global_attention_mask=global_attention_mask, + token_type_ids=token_type_ids, + ) + sequence_output, pooled_output = model( + input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask + ) + sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask) + + result = { + "sequence_output": sequence_output, + "pooled_output": pooled_output, + } + self.parent.assertListEqual( + shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size] + ) + self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size]) + + def create_and_check_longformer_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFLongformerForMaskedLM(config=config) + loss, prediction_scores = model( + input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels + ) + result = { + "loss": loss, + "prediction_scores": prediction_scores, + } + self.parent.assertListEqual( + shape_list(result["prediction_scores"]), [self.batch_size, self.seq_length, self.vocab_size] + ) + + def create_and_check_longformer_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFLongformerForQuestionAnswering(config=config) + loss, start_logits, end_logits = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + result = { + "loss": loss, + "start_logits": start_logits, + "end_logits": end_logits, + } + self.parent.assertListEqual(shape_list(result["start_logits"]), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(shape_list(result["end_logits"]), [self.batch_size, self.seq_length]) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + # global attention mask has to be partly defined + # to trace all weights + global_attention_mask = tf.concat( + [tf.zeros_like(input_ids)[:, :-1], tf.ones_like(input_ids)[:, -1:]], axis=-1, + ) + + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + "global_attention_mask": global_attention_mask, + } + return config, inputs_dict + + def prepare_config_and_inputs_for_question_answering(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + # Replace sep_token_id by some random id + input_ids = tf.where(input_ids == config.sep_token_id, 0, input_ids) + # Make sure there are exactly three sep_token_id + input_ids = tf.concat([input_ids[:, :-3], tf.ones_like(input_ids)[:, -3:] * config.sep_token_id], axis=-1) + input_mask = tf.ones_like(input_ids) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + +@require_tf +class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): + test_pruning = False # pruning is not supported + test_headmasking = False # head masking is not supported + test_torchscript = False + + all_model_classes = ( + (TFLongformerModel, TFLongformerForMaskedLM, TFLongformerForQuestionAnswering,) if is_tf_available() else () + ) + + def setUp(self): + self.model_tester = TFLongformerModelTester(self) + self.config_tester = ConfigTester(self, config_class=LongformerConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_longformer_model_attention_mask_determinism(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs) + + def test_longformer_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_model(*config_and_inputs) + + def test_longformer_model_global_attention_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs) + + def test_longformer_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs) + + def test_longformer_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering() + self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs) + + +@require_tf +class TFLongformerModelIntegrationTest(unittest.TestCase): + def _get_hidden_states(self): + return tf.convert_to_tensor( + [ + [ + [ + 4.98332758e-01, + 2.69175139e00, + -7.08081422e-03, + 1.04915401e00, + -1.83476661e00, + 7.67220476e-01, + 2.98580543e-01, + 2.84803992e-02, + ], + [ + -7.58357372e-01, + 4.20635998e-01, + -4.04739919e-02, + 1.59924145e-01, + 2.05135748e00, + -1.15997978e00, + 5.37166397e-01, + 2.62873606e-01, + ], + [ + -1.69438001e00, + 4.17574660e-01, + -1.49196962e00, + -1.76483717e00, + -1.94566312e-01, + -1.71183858e00, + 7.72903565e-01, + -1.11557056e00, + ], + [ + 5.44028163e-01, + 2.05466114e-01, + -3.63045868e-01, + 2.41865062e-01, + 3.20348382e-01, + -9.05611176e-01, + -1.92690727e-01, + -1.19917547e00, + ], + ] + ], + dtype=tf.float32, + ) + + def test_diagonalize(self): + hidden_states = self._get_hidden_states() + hidden_states = tf.reshape(hidden_states, (1, 8, 4)) # set seq length = 8, hidden dim = 4 + chunked_hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2) + window_overlap_size = shape_list(chunked_hidden_states)[2] + self.assertTrue(window_overlap_size == 4) + + padded_hidden_states = TFLongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states) + + self.assertTrue( + shape_list(padded_hidden_states)[-1] == shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1 + ) + + # first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000] + tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], rtol=1e-3) + tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3) + + # last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629] + tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], rtol=1e-3) + tf.debugging.assert_near( + padded_hidden_states[0, 0, -1, :3], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3 + ) + + def test_pad_and_transpose_last_two_dims(self): + hidden_states = self._get_hidden_states() + self.assertTrue(shape_list(hidden_states), [1, 8, 4]) + + # pad along seq length dim + paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32) + + padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings) + self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5]) + + expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32) + tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6) + tf.debugging.assert_near( + hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6 + ) + + def test_mask_invalid_locations(self): + hidden_states = self._get_hidden_states() + batch_size = 1 + seq_length = 8 + hidden_size = 4 + hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size)) + hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2) + + hid_states_1 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states, 1) + hid_states_2 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states, 2) + hid_states_3 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states[:, :, :, :3], 2) + hid_states_4 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states[:, :, 2:, :], 2) + + self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8) + self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24) + self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24) + self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12) + + def test_chunk(self): + hidden_states = self._get_hidden_states() + batch_size = 1 + seq_length = 8 + hidden_size = 4 + hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size)) + + chunked_hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2) + + # expected slices across chunk and seq length dim + expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32) + expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32) + + self.assertTrue(shape_list(chunked_hidden_states) == [1, 3, 4, 4]) + tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3) + tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3) + + def test_layer_local_attn(self): + model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False) + layer = model.longformer.encoder.layer[0].attention.self_attention + hidden_states = self._get_hidden_states() + batch_size, seq_length, hidden_size = hidden_states.shape + + attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32) + attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask) + + output_hidden_states = layer([hidden_states, attention_mask, None])[0] + + expected_slice = tf.convert_to_tensor( + [0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32 + ) + + self.assertTrue(output_hidden_states.shape, (1, 4, 8)) + tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3) + + def test_layer_global_attn(self): + model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False) + layer = model.longformer.encoder.layer[0].attention.self_attention + hidden_states = self._get_hidden_states() + + hidden_states = tf.concat([self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0) + batch_size, seq_length, hidden_size = hidden_states.shape + + # create attn mask + attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32) + attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32) + + attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1) + attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1) + attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2) + attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0) + + output_hidden_states = layer([hidden_states, attention_mask, None])[0] + + self.assertTrue(output_hidden_states.shape, (2, 4, 8)) + expected_slice_0 = tf.convert_to_tensor( + [-0.06508, -0.039306, 0.030934, -0.03417, -0.00656, -0.01553, -0.02088, -0.04938], dtype=tf.dtypes.float32 + ) + + expected_slice_1 = tf.convert_to_tensor( + [-0.04055, -0.038399, 0.0396, -0.03735, -0.03415, 0.01357, 0.00145, -0.05709], dtype=tf.dtypes.float32 + ) + + tf.debugging.assert_near(output_hidden_states[0, 2], expected_slice_0, rtol=1e-3) + tf.debugging.assert_near(output_hidden_states[1, -2], expected_slice_1, rtol=1e-3) + + @slow + def test_inference_no_head(self): + model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096") + + # 'Hello world!' + input_ids = tf.convert_to_tensor([[0, 20920, 232, 328, 1437, 2]], dtype=tf.dtypes.int32) + attention_mask = tf.ones(shape_list(input_ids), dtype=tf.dtypes.int32) + + output = model(input_ids, attention_mask=attention_mask)[0] + output_without_mask = model(input_ids)[0] + + expected_output_slice = tf.convert_to_tensor( + [0.0549, 0.1087, -0.1119, -0.0368, 0.0250], dtype=tf.dtypes.float32 + ) + + tf.debugging.assert_near(output[0, 0, -5:], expected_output_slice, rtol=1e-3) + tf.debugging.assert_near(output_without_mask[0, 0, -5:], expected_output_slice, rtol=1e-3) + + @slow + def test_inference_no_head_long(self): + model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096") + + # 'Hello world! ' repeated 1000 times + input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.dtypes.int32) + + attention_mask = tf.ones(shape_list(input_ids), dtype=tf.dtypes.int32) + global_attention_mask = tf.zeros(shape_list(input_ids), dtype=tf.dtypes.int32) + # Set global attention on a few random positions + global_attention_mask = tf.tensor_scatter_nd_update( + global_attention_mask, tf.constant([[0, 1], [0, 4], [0, 21]]), tf.constant([1, 1, 1]) + ) + + output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0] + + expected_output_sum = tf.constant(74585.875) + expected_output_mean = tf.constant(0.024267) + + # assert close + tf.debugging.assert_near(tf.reduce_sum(output), expected_output_sum, rtol=1e-4) + tf.debugging.assert_near(tf.reduce_mean(output), expected_output_mean, rtol=1e-4) + + @slow + def test_inference_masked_lm_long(self): + model = TFLongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096") + + # 'Hello world! ' repeated 1000 times + input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.dtypes.int32) + + loss, prediction_scores = model(input_ids, labels=input_ids) + + expected_loss = tf.constant(0.0073798) + expected_prediction_scores_sum = tf.constant(-610476600.0) + expected_prediction_scores_mean = tf.constant(-3.03477) + + # assert close + tf.debugging.assert_near(tf.reduce_mean(loss), expected_loss, rtol=1e-4) + tf.debugging.assert_near(tf.reduce_sum(prediction_scores), expected_prediction_scores_sum, rtol=1e-4) + tf.debugging.assert_near(tf.reduce_mean(prediction_scores), expected_prediction_scores_mean, rtol=1e-4)