TF Longformer (#5764)

* improve names and tests longformer

* more and better tests for longformer

* add first tf test

* finalize tf basic op functions

* fix merge

* tf shape test passes

* narrow down discrepancies

* make longformer local attn tf work

* correct tf longformer

* add first global attn function

* add more global longformer func

* advance tf longformer

* finish global attn

* upload big model

* finish all tests

* correct false any statement

* fix common tests

* make all tests pass except keras save load

* fix some tests

* fix torch test import

* finish tests

* fix test

* fix torch tf tests

* add docs

* finish docs

* Update src/transformers/modeling_longformer.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/modeling_tf_longformer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply Lysandre's suggestions

* revert to assert statement because the function will fail otherwise

* apply Sylvain's recommendations

* Update src/transformers/modeling_longformer.py

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* Update src/transformers/modeling_tf_longformer.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Patrick von Platen 2020-08-10 23:25:06 +02:00 committed by GitHub
parent 3425936643
commit 00bb0b25ed
7 changed files with 2371 additions and 186 deletions


@ -102,3 +102,25 @@ LongformerForQuestionAnswering
.. autoclass:: transformers.LongformerForQuestionAnswering
:members:
TFLongformerModel
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerModel
:members:
TFLongformerForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForMaskedLM
:members:
TFLongformerForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForQuestionAnswering
:members:
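For orientation, a minimal usage sketch of the newly documented TF classes (an editorial addition, not part of the commit; the checkpoint name follows the integration tests further below):

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello world!", return_tensors="tf")
# put global attention on the first (<s>) token; every other token uses local sliding-window attention
global_attention_mask = tf.concat(
    [tf.ones_like(inputs["input_ids"])[:, :1], tf.zeros_like(inputs["input_ids"])[:, 1:]], axis=-1
)
sequence_output = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
)[0]  # last hidden states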


@ -399,6 +399,7 @@ if is_torch_available():
LongformerForMultipleChoice,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LongformerSelfAttention,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
@ -568,6 +569,14 @@ if is_tf_available():
TFGPT2PreTrainedModel,
)
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerModel,
TFLongformerForMaskedLM,
TFLongformerForQuestionAnswering,
TFLongformerSelfAttention,
)
from .modeling_tf_mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertModel,


@ -71,7 +71,6 @@ def _get_question_end_index(input_ids, sep_token_id):
assert (
sep_token_indices.shape[0] == 3 * batch_size
), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
@ -81,7 +80,6 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
before `sep_token_id` if `before_sep_token is True` else after
`sep_token_id`.
"""
question_end_index = _get_question_end_index(input_ids, sep_token_id)
question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
# bool attention mask with True in locations of global attention
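# Illustration (editorial, values not from the diff): for a RoBERTa-style question/context pair
#     <s> q1 q2 </s> </s> c1 c2 c3 </s>      (sep_token_id occurs exactly three times)
# the first </s> sits at index 3, so with `before_sep_token=True` the returned mask is
#     [1, 1, 1, 0, 0, 0, 0, 0, 0]            i.e. global attention on <s> and the question tokens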
@ -131,6 +129,172 @@ class LongformerSelfAttention(nn.Module):
self.one_sided_attn_window_size = attention_window // 2
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be a multiple of `attention_window`.
Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer.
The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
-ve: no attention
0: local attention
+ve: global attention
"""
attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
# is index masked or global attention
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
hidden_states = hidden_states.transpose(0, 1)
# project hidden states
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
seq_len, batch_size, embed_dim = hidden_states.size()
assert (
embed_dim == self.embed_dim
), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
# normalize query
query_vectors /= math.sqrt(self.head_dim)
query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query_vectors, key_vectors, self.one_sided_attn_window_size
)
# values to pad for attention probs
remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
remove_from_windowed_attention_mask, -10000.0
)
# diagonal mask with zeros everywhere and -inf in place of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
)
# pad local attention probs
attn_scores += diagonal_mask
assert list(attn_scores.size()) == [
batch_size,
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
# compute local attention probs from global attention keys and concat over window dim
if is_global_attn:
# compute global attn indices required throughout forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn)
# calculate global attn probs from global key
global_key_attn_scores = self._concat_with_global_key_attn_probs(
query_vectors=query_vectors,
key_vectors=key_vectors,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
# free memory
del global_key_attn_scores
attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
attn_probs = attn_probs_fp32.type_as(attn_scores)
# free memory
del attn_probs_fp32
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
# apply dropout
attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# compute local attention output with global attention value and add
if is_global_attn:
# compute sum of global and local attn
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
# compute local attn only
attn_output = self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self.one_sided_attn_window_size
)
assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
if is_global_attn:
global_attn_output = self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
)
# get only non zero global attn output
nonzero_global_attn_output = global_attn_output[
is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
]
# overwrite values with global attention
attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
len(is_local_index_global_attn_nonzero[0]), -1
)
attn_output = attn_output.transpose(0, 1)
if output_attentions:
if is_global_attn:
# With global attention, return global attention probabilities only
# batch_size x num_heads x max_num_global_attention_tokens x sequence_length
# which is the attention weights from tokens with global attention to all tokens
# It does not return local attention
# In case of a variable number of global attention tokens in the rows of a batch,
# attn_probs are padded with -10000.0 attention scores
attn_probs = attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
else:
# without global attention, return local attention probabilities
# batch_size x num_heads x sequence_length x window_size
# which is the attention weights of every token attending to its neighbours
attn_probs = attn_probs.permute(0, 2, 1, 3)
outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
return outputs
@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
"""pads rows and then flips rows and columns"""
@ -143,8 +307,20 @@ class LongformerSelfAttention(nn.Module):
return hidden_states_padded
@staticmethod
def _pad_by_window_overlap_except_last_row(chunked_hidden_states):
"""shift every row 1 step right, converting columns into diagonals"""
def _pad_and_diagonalize(chunked_hidden_states):
"""shift every row 1 step right, converting columns into diagonals.
Example:
chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
2.0514, -1.1600, 0.5372, 0.2629 ]
window_overlap = num_rows = 4
(pad & diagonalize) =>
[ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
chunked_hidden_states = F.pad(
chunked_hidden_states, (0, window_overlap + 1)
@ -181,7 +357,8 @@ class LongformerSelfAttention(nn.Module):
chunk_stride[1] = chunk_stride[1] // 2
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
def _mask_invalid_locations(self, input_tensor, affected_seq_len) -> torch.Tensor:
@staticmethod
def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
beginning_mask = beginning_mask_2d[None, :, None, :]
ending_mask = beginning_mask.flip(dims=(1, 3))
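As an editorial aside (not part of the diff), the overlapping chunking built with `as_strided` above (the `_chunk` helper exercised in the tests below) can also be reproduced with `torch.Tensor.unfold`; a minimal sketch with illustrative sizes:

import torch

batch_times_heads, seq_len, head_dim, window_overlap = 2, 8, 4, 2  # illustrative sizes
hidden_states = torch.randn(batch_times_heads, seq_len, head_dim)

# overlapping chunks of size 2 * window_overlap with stride window_overlap, matching the
# (batch*heads, seq_len // window_overlap - 1, 2 * window_overlap, head_dim) layout of _chunk
chunks = hidden_states.unfold(1, 2 * window_overlap, window_overlap).transpose(2, 3)
assert chunks.shape == (batch_times_heads, seq_len // window_overlap - 1, 2 * window_overlap, head_dim)
# chunk i starts at position i * window_overlap of the original sequence
assert torch.equal(chunks[:, 1, 0], hidden_states[:, window_overlap])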
@ -243,6 +420,7 @@ class LongformerSelfAttention(nn.Module):
diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
:, :, -(window_overlap + 1) : -1, window_overlap + 1 :
]
diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
:, 0, : window_overlap - 1, 1 - window_overlap :
]
@ -261,11 +439,13 @@ class LongformerSelfAttention(nn.Module):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
Returned tensor will be of the same shape as `attn_probs`"""
batch_size, seq_len, num_heads, head_dim = value.size()
assert seq_len % (window_overlap * 2) == 0
assert attn_probs.size()[:3] == value.size()[:3]
assert attn_probs.size(3) == 2 * window_overlap + 1
chunks_count = seq_len // window_overlap - 1
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 * window_overlap
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
)
@ -287,178 +467,11 @@ class LongformerSelfAttention(nn.Module):
)
chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)
chunked_attn_probs = self._pad_by_window_overlap_except_last_row(chunked_attn_probs)
chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`.
Padding to `attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer.
The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
-ve: no attention
0: local attention
+ve: global attention
"""
attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
# is index masked or global attention
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
hidden_states = hidden_states.transpose(0, 1)
# project hidden states
query_vectors = self.query(hidden_states)
key_vectors = self.key(hidden_states)
value_vectors = self.value(hidden_states)
seq_len, batch_size, embed_dim = hidden_states.size()
assert (
embed_dim == self.embed_dim
), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"
# normalize query
query_vectors /= math.sqrt(self.head_dim)
query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# attn_probs = (batch_size, seq_len, num_heads, window*2+1)
attn_scores = self._sliding_chunks_query_key_matmul(
query_vectors, key_vectors, self.one_sided_attn_window_size
)
# values to pad for attention probs
remove_from_windowed_attention_mask = (attention_mask != 0).unsqueeze(dim=-1).unsqueeze(dim=-1)
# cast to fp32/fp16 then replace 1's with -inf
float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
remove_from_windowed_attention_mask, -10000.0
)
# diagonal mask with zeros everywhere and -inf inplace of padding
diagonal_mask = self._sliding_chunks_query_key_matmul(
float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
)
# pad local attention probs
attn_scores += diagonal_mask
assert list(attn_scores.size()) == [
batch_size,
seq_len,
self.num_heads,
self.one_sided_attn_window_size * 2 + 1,
], f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads}, {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
# compute local attention probs from global attention keys and contact over window dim
if is_global_attn:
# compute global attn indices required through out forward fn
(
max_num_global_attn_indices,
is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero,
) = self._get_global_attn_indices(is_index_global_attn)
# calculate global attn probs from global key
global_key_attn_scores = self._concat_with_global_key_attn_probs(
query_vectors=query_vectors,
key_vectors=key_vectors,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
)
# concat to attn_probs
# (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)
# free memory
del global_key_attn_scores
attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
attn_probs = attn_probs_fp32.type_as(attn_scores)
# free memory
del attn_probs_fp32
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_probs = torch.masked_fill(attn_probs, is_index_masked.unsqueeze(-1).unsqueeze(-1), 0.0)
# apply dropout
attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
# compute local attention output with global attention value and add
if is_global_attn:
# compute sum of global and local attn
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
)
else:
# compute local attn only
attn_output = self._sliding_chunks_matmul_attn_probs_value(
attn_probs, value_vectors, self.one_sided_attn_window_size
)
assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()
# compute value for global attention and overwrite to attention output
# TODO: remove the redundant computation
if is_global_attn:
global_attn_output = self._compute_global_attn_output_from_hidden(
hidden_states=hidden_states,
max_num_global_attn_indices=max_num_global_attn_indices,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
is_index_masked=is_index_masked,
)
# get only non zero global attn output
nonzero_global_attn_output = global_attn_output[
is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
]
# overwrite values with global attention
attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
len(is_local_index_global_attn_nonzero[0]), -1
)
attn_output = attn_output.transpose(0, 1)
if output_attentions:
if is_global_attn:
# With global attention, return global attention probabilities only
# batch_size x num_heads x sequence_length x window_size
# which is the attention weights from all tokens to all tokens for global attention
# It doesn't not return local attention. Only tokens with global attention have values > 0.0
attn_probs = attn_probs[:, :, :, :max_num_global_attn_indices]
# pad attn_probs to max length with 0.0 since global attn did not attend there
window_size = self.one_sided_attn_window_size * 2 + 1
attn_probs = F.pad(attn_probs, (0, window_size - max_num_global_attn_indices), value=0.0,)
attn_probs = attn_probs.permute(0, 2, 1, 3)
else:
# without global attention, return local attention probabilities
# batch_size x num_heads x sequence_length x window_size
# which is the attention weights of every token attending to its neighbours
attn_probs = attn_probs.permute(0, 2, 1, 3)
outputs = (attn_output, attn_probs) if output_attentions else (attn_output,)
return outputs
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
@ -503,12 +516,16 @@ class LongformerSelfAttention(nn.Module):
key_vectors_only_global = key_vectors.new_zeros(
batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
)
key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
attn_probs_from_global_key[
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
] = -10000.0
return attn_probs_from_global_key
def _compute_attn_output_with_global_indices(
@ -600,7 +617,7 @@ class LongformerSelfAttention(nn.Module):
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
] = -10000.0
global_attn_scores = global_attn_scores.masked_fill(is_index_masked.unsqueeze(1).unsqueeze(2), -10000.0,)
global_attn_scores = global_attn_scores.masked_fill(is_index_masked[:, None, None, :], -10000.0,)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
@ -754,7 +771,6 @@ class LongformerPreTrainedModel(PreTrainedModel):
LONGFORMER_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
usage and behavior.
@ -823,13 +839,13 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
)
class LongformerModel(LongformerPreTrainedModel):
"""
This class overrides :class:`~transformers.RobertaModel` to provide the ability to process
long sequences following the selfattention approach described in `Longformer: the Long-Document Transformer
<https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer selfattention
This class copies code from :class:`~transformers.RobertaModel` and overrides the standard self-attention with Longformer self-attention to provide the ability to process
long sequences following the self-attention approach described in `Longformer: the Long-Document Transformer
<https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer self-attention
combines a local (sliding window) and global attention to extend to long documents without the O(n^2) increase in
memory and compute.
The selfattention module `LongformerSelfAttention` implemented here supports the combination of local and
The self-attention module `LongformerSelfAttention` implemented here supports the combination of local and
global attention but it lacks support for autoregressive attention and dilated attention. Autoregressive
and dilated attention are more relevant for autoregressive language modeling than finetuning on downstream
tasks. A future release will add support for autoregressive attention, but the support for dilated attention
@ -883,7 +899,7 @@ class LongformerModel(LongformerPreTrainedModel):
inputs_embeds: torch.Tensor,
pad_token_id: int,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
"""A helper function to pad tokens and mask to work with implementation of Longformer self-attention."""
# padding
attention_window = (
self.config.attention_window
@ -1053,6 +1069,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
self.init_weights()
def get_output_embeddings(self):
return self.lm_head.decoder
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1315,11 +1334,14 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# set global attention on question tokens
if global_attention_mask is None:
logger.info("Initializing global attention on question tokens...")
# put global attention on all tokens until `config.sep_token_id` is reached
global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
if input_ids is None:
logger.warning(
"It is not possible to automatically generate the `global_attention_mask` because input_ids is None. Please make sure that it is correctly set."
)
else:
# set global attention on question tokens automatically
global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
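# note: RoBERTa-style question/context pairs are encoded as `<s> question </s></s> context </s>`,
# so `input_ids` contains exactly three `sep_token_id` tokens per sample and every token before
# the first `</s>` (i.e. the question) automatically receives global attention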
outputs = self.longformer(
input_ids,
@ -1504,7 +1526,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# set global attention on question tokens
if global_attention_mask is None:
if global_attention_mask is None and input_ids is not None:
logger.info("Initializing global attention on multiple choice...")
# put global attention on all tokens after `config.sep_token_id`
global_attention_mask = torch.stack(


@ -29,6 +29,7 @@ from .configuration_auto import (
ElectraConfig,
FlaubertConfig,
GPT2Config,
LongformerConfig,
MobileBertConfig,
OpenAIGPTConfig,
RobertaConfig,
@ -93,6 +94,7 @@ from .modeling_tf_flaubert import (
TFFlaubertWithLMHeadModel,
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
from .modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
@ -149,6 +151,7 @@ TF_MODEL_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertModel),
(CamembertConfig, TFCamembertModel),
(XLMRobertaConfig, TFXLMRobertaModel),
(LongformerConfig, TFLongformerModel),
(RobertaConfig, TFRobertaModel),
(BertConfig, TFBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
@ -191,6 +194,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
@ -226,6 +230,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(MobileBertConfig, TFMobileBertForMaskedLM),
@ -259,6 +264,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(LongformerConfig, TFLongformerForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),

File diff suppressed because it is too large.
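The suppressed diff is the largest new file in the commit (presumably the TF Longformer modeling code itself). As an editorial sketch of what the auto-model mappings registered in the previous file enable (not part of the diff; checkpoint name taken from the integration tests):

from transformers import TFAutoModel, TFAutoModelForQuestionAnswering

# thanks to the (LongformerConfig, TFLongformer...) entries added above, the TF auto classes
# now resolve Longformer checkpoints to the new TF Longformer classes
model = TFAutoModel.from_pretrained("allenai/longformer-base-4096")
qa_model = TFAutoModelForQuestionAnswering.from_pretrained("allenai/longformer-base-4096")
print(type(model).__name__)     # TFLongformerModel
print(type(qa_model).__name__)  # TFLongformerForQuestionAnswering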


@ -33,6 +33,7 @@ if is_torch_available():
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LongformerForMultipleChoice,
LongformerSelfAttention,
)
@ -325,7 +326,209 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
@require_torch
class LongformerModelIntegrationTest(unittest.TestCase):
def _get_hidden_states(self):
return torch.tensor(
[
[
[
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]
],
dtype=torch.float32,
device=torch_device,
)
def test_diagonalize(self):
hidden_states = self._get_hidden_states()
hidden_states = hidden_states.reshape((1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
window_overlap_size = chunked_hidden_states.shape[2]
self.assertTrue(window_overlap_size == 4)
padded_hidden_states = LongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states)
self.assertTrue(padded_hidden_states.shape[-1] == chunked_hidden_states.shape[-1] + window_overlap_size - 1)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
self.assertTrue(torch.allclose(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], atol=1e-3))
self.assertTrue(
torch.allclose(
padded_hidden_states[0, 0, 0, 4:],
torch.zeros((3,), device=torch_device, dtype=torch.float32),
atol=1e-3,
)
)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
self.assertTrue(torch.allclose(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], atol=1e-3))
self.assertTrue(
torch.allclose(
padded_hidden_states[0, 0, -1, :3],
torch.zeros((3,), device=torch_device, dtype=torch.float32),
atol=1e-3,
)
)
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertEqual(hidden_states.shape, (1, 4, 8))
padding = (0, 0, 0, 1)
padded_hidden_states = LongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, padding)
self.assertEqual(padded_hidden_states.shape, (1, 8, 5))
expected_added_dim = torch.zeros((5,), device=torch_device, dtype=torch.float32)
self.assertTrue(torch.allclose(expected_added_dim, padded_hidden_states[0, -1, :], atol=1e-6))
self.assertTrue(torch.allclose(hidden_states[0, -1, :], padded_hidden_states.view(1, -1)[0, 24:32], atol=1e-6))
def test_chunk(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size))
chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = torch.tensor(
[0.4983, -0.7584, -1.6944], device=torch_device, dtype=torch.float32
)
expected_slice_along_chunk = torch.tensor(
[0.4983, -1.8348, -0.7584, 2.0514], device=torch_device, dtype=torch.float32
)
self.assertTrue(torch.allclose(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, atol=1e-3))
self.assertTrue(torch.allclose(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, atol=1e-3))
self.assertEqual(chunked_hidden_states.shape, (1, 3, 4, 4))
def test_mask_invalid_locations(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = hidden_states.reshape((batch_size, seq_length, hidden_size))
chunked_hidden_states = LongformerSelfAttention._chunk(hidden_states, window_overlap=2)
hid_states_1 = chunked_hidden_states.clone()
LongformerSelfAttention._mask_invalid_locations(hid_states_1, 1)
self.assertTrue(torch.isinf(hid_states_1).sum().item() == 8)
hid_states_2 = chunked_hidden_states.clone()
LongformerSelfAttention._mask_invalid_locations(hid_states_2, 2)
self.assertTrue(torch.isinf(hid_states_2).sum().item() == 24)
hid_states_3 = chunked_hidden_states.clone()[:, :, :, :3]
LongformerSelfAttention._mask_invalid_locations(hid_states_3, 2)
self.assertTrue(torch.isinf(hid_states_3).sum().item() == 24)
hid_states_4 = chunked_hidden_states.clone()[:, :, 2:, :]
LongformerSelfAttention._mask_invalid_locations(hid_states_4, 2)
self.assertTrue(torch.isinf(hid_states_4).sum().item() == 12)
def test_layer_local_attn(self):
model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
model.eval()
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
attention_mask[:, :, :, -2:] = -10000
output_hidden_states = layer(hidden_states, attention_mask)[0]
self.assertEqual(output_hidden_states.shape, (1, 4, 8))
self.assertTrue(
torch.allclose(
output_hidden_states[0, 1],
torch.tensor(
[0.0019, 0.0122, -0.0171, -0.0256, -0.0300, 0.0173, -0.0115, 0.0048],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
def test_layer_global_attn(self):
model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
model.eval()
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = torch.cat([self._get_hidden_states(), self._get_hidden_states() - 0.5], dim=0)
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
# create attn mask
attention_mask[0, :, :, -2:] = 10000.0
attention_mask[0, :, :, -1:] = -10000.0
attention_mask[1, :, :, 1:] = 10000.0
output_hidden_states = layer(hidden_states, attention_mask)[0]
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
self.assertTrue(
torch.allclose(
output_hidden_states[0, 2],
torch.tensor(
[-0.0651, -0.0393, 0.0309, -0.0342, -0.0066, -0.0155, -0.0209, -0.0494],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
self.assertTrue(
torch.allclose(
output_hidden_states[1, -2],
torch.tensor(
[-0.0405, -0.0384, 0.0396, -0.0374, -0.0341, 0.0136, 0.0014, -0.0571],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
@ -371,13 +574,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
input_ids = torch.tensor(
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
) # long input
input_ids = input_ids.to(torch_device)
loss, prediction_scores = model(input_ids, labels=input_ids)
expected_loss = torch.tensor(0.0074, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
expected_prediction_scores_mean = torch.tensor(-3.0348, device=torch_device)
input_ids = input_ids.to(torch_device)
self.assertTrue(torch.allclose(loss, expected_loss, atol=1e-4))
self.assertTrue(torch.allclose(prediction_scores.sum(), expected_prediction_scores_sum, atol=1e-4))


@ -0,0 +1,531 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import (
LongformerConfig,
TFLongformerModel,
TFLongformerForMaskedLM,
TFLongformerForQuestionAnswering,
TFLongformerSelfAttention,
)
def shape_list(x):
"""
copied from transformers.modeling_tf_utils
"""
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
class TFLongformerModelTester:
def __init__(
self, parent,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = True
self.use_input_mask = True
self.use_token_type_ids = True
self.use_labels = True
self.vocab_size = 99
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 512
self.type_vocab_size = 16
self.type_sequence_label_size = 2
self.initializer_range = 0.02
self.num_labels = 3
self.num_choices = 4
self.scope = None
self.attention_window = 4
# `ModelTesterMixin.test_attention_outputs` is expecting attention tensors to be of size
# [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
# because its local attention only attends to `self.attention_window` and one before and one after
self.key_length = self.attention_window + 2
# because of padding, `encoder_seq_length` is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
self.encoder_seq_length = (
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
)
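# e.g. with the values above: seq_length = 7, attention_window = 4
# -> encoder_seq_length = 7 + (4 - 7 % 4) % 4 = 8 and key_length = 4 + 2 = 6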
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = LongformerConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
attention_window=self.attention_window,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_attention_mask_determinism(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLongformerModel(config=config)
attention_mask = tf.ones(input_ids.shape, dtype=tf.dtypes.int32)
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
def create_and_check_longformer_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLongformerModel(config=config)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLongformerModel(config=config)
half_input_mask_length = shape_list(input_mask)[-1] // 2
global_attention_mask = tf.concat(
[
tf.zeros_like(input_mask)[:, :half_input_mask_length],
tf.ones_like(input_mask)[:, half_input_mask_length:],
],
axis=-1,
)
sequence_output, pooled_output = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
)
sequence_output, pooled_output = model(
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
)
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
shape_list(result["sequence_output"]), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(shape_list(result["pooled_output"]), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLongformerForMaskedLM(config=config)
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
shape_list(result["prediction_scores"]), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_longformer_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFLongformerForQuestionAnswering(config=config)
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(shape_list(result["start_logits"]), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result["end_logits"]), [self.batch_size, self.seq_length])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
# global attention mask has to be partly defined
# to trace all weights
global_attention_mask = tf.concat(
[tf.zeros_like(input_ids)[:, :-1], tf.ones_like(input_ids)[:, -1:]], axis=-1,
)
inputs_dict = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": input_mask,
"global_attention_mask": global_attention_mask,
}
return config, inputs_dict
def prepare_config_and_inputs_for_question_answering(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
# Replace sep_token_id by some random id
input_ids = tf.where(input_ids == config.sep_token_id, 0, input_ids)
# Make sure there are exactly three sep_token_id
input_ids = tf.concat([input_ids[:, :-3], tf.ones_like(input_ids)[:, -3:] * config.sep_token_id], axis=-1)
input_mask = tf.ones_like(input_ids)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@require_tf
class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
test_pruning = False # pruning is not supported
test_headmasking = False # head masking is not supported
test_torchscript = False
all_model_classes = (
(TFLongformerModel, TFLongformerForMaskedLM, TFLongformerForQuestionAnswering,) if is_tf_available() else ()
)
def setUp(self):
self.model_tester = TFLongformerModelTester(self)
self.config_tester = ConfigTester(self, config_class=LongformerConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_longformer_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
def test_longformer_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
@require_tf
class TFLongformerModelIntegrationTest(unittest.TestCase):
def _get_hidden_states(self):
return tf.convert_to_tensor(
[
[
[
4.98332758e-01,
2.69175139e00,
-7.08081422e-03,
1.04915401e00,
-1.83476661e00,
7.67220476e-01,
2.98580543e-01,
2.84803992e-02,
],
[
-7.58357372e-01,
4.20635998e-01,
-4.04739919e-02,
1.59924145e-01,
2.05135748e00,
-1.15997978e00,
5.37166397e-01,
2.62873606e-01,
],
[
-1.69438001e00,
4.17574660e-01,
-1.49196962e00,
-1.76483717e00,
-1.94566312e-01,
-1.71183858e00,
7.72903565e-01,
-1.11557056e00,
],
[
5.44028163e-01,
2.05466114e-01,
-3.63045868e-01,
2.41865062e-01,
3.20348382e-01,
-9.05611176e-01,
-1.92690727e-01,
-1.19917547e00,
],
]
],
dtype=tf.float32,
)
def test_diagonalize(self):
hidden_states = self._get_hidden_states()
hidden_states = tf.reshape(hidden_states, (1, 8, 4)) # set seq length = 8, hidden dim = 4
chunked_hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
window_overlap_size = shape_list(chunked_hidden_states)[2]
self.assertTrue(window_overlap_size == 4)
padded_hidden_states = TFLongformerSelfAttention._pad_and_diagonalize(chunked_hidden_states)
self.assertTrue(
shape_list(padded_hidden_states)[-1] == shape_list(chunked_hidden_states)[-1] + window_overlap_size - 1
)
# first row => [0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000]
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, :4], chunked_hidden_states[0, 0, 0], rtol=1e-3)
tf.debugging.assert_near(padded_hidden_states[0, 0, 0, 4:], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3)
# last row => [0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629]
tf.debugging.assert_near(padded_hidden_states[0, 0, -1, 3:], chunked_hidden_states[0, 0, -1], rtol=1e-3)
tf.debugging.assert_near(
padded_hidden_states[0, 0, -1, :3], tf.zeros((3,), dtype=tf.dtypes.float32), rtol=1e-3
)
def test_pad_and_transpose_last_two_dims(self):
hidden_states = self._get_hidden_states()
self.assertEqual(shape_list(hidden_states), [1, 4, 8])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
)
def test_mask_invalid_locations(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size))
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
hid_states_1 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states, 1)
hid_states_2 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states, 2)
hid_states_3 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states[:, :, :, :3], 2)
hid_states_4 = TFLongformerSelfAttention._mask_invalid_locations(hidden_states[:, :, 2:, :], 2)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_1), tf.dtypes.int32)) == 8)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_2), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_3), tf.dtypes.int32)) == 24)
self.assertTrue(tf.math.reduce_sum(tf.cast(tf.math.is_inf(hid_states_4), tf.dtypes.int32)) == 12)
def test_chunk(self):
hidden_states = self._get_hidden_states()
batch_size = 1
seq_length = 8
hidden_size = 4
hidden_states = tf.reshape(hidden_states, (batch_size, seq_length, hidden_size))
chunked_hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
# expected slices across chunk and seq length dim
expected_slice_along_seq_length = tf.convert_to_tensor([0.4983, -0.7584, -1.6944], dtype=tf.dtypes.float32)
expected_slice_along_chunk = tf.convert_to_tensor([0.4983, -1.8348, -0.7584, 2.0514], dtype=tf.dtypes.float32)
self.assertTrue(shape_list(chunked_hidden_states) == [1, 3, 4, 4])
tf.debugging.assert_near(chunked_hidden_states[0, :, 0, 0], expected_slice_along_seq_length, rtol=1e-3)
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
def test_layer_local_attn(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False)
layer = model.longformer.encoder.layer[0].attention.self_attention
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
expected_slice = tf.convert_to_tensor(
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
)
self.assertEqual(output_hidden_states.shape, (1, 4, 8))
tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3)
def test_layer_global_attn(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False)
layer = model.longformer.encoder.layer[0].attention.self_attention
hidden_states = self._get_hidden_states()
hidden_states = tf.concat([self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
batch_size, seq_length, hidden_size = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
self.assertEqual(output_hidden_states.shape, (2, 4, 8))
expected_slice_0 = tf.convert_to_tensor(
[-0.06508, -0.039306, 0.030934, -0.03417, -0.00656, -0.01553, -0.02088, -0.04938], dtype=tf.dtypes.float32
)
expected_slice_1 = tf.convert_to_tensor(
[-0.04055, -0.038399, 0.0396, -0.03735, -0.03415, 0.01357, 0.00145, -0.05709], dtype=tf.dtypes.float32
)
tf.debugging.assert_near(output_hidden_states[0, 2], expected_slice_0, rtol=1e-3)
tf.debugging.assert_near(output_hidden_states[1, -2], expected_slice_1, rtol=1e-3)
@slow
def test_inference_no_head(self):
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
# 'Hello world!'
input_ids = tf.convert_to_tensor([[0, 20920, 232, 328, 1437, 2]], dtype=tf.dtypes.int32)
attention_mask = tf.ones(shape_list(input_ids), dtype=tf.dtypes.int32)
output = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]
expected_output_slice = tf.convert_to_tensor(
[0.0549, 0.1087, -0.1119, -0.0368, 0.0250], dtype=tf.dtypes.float32
)
tf.debugging.assert_near(output[0, 0, -5:], expected_output_slice, rtol=1e-3)
tf.debugging.assert_near(output_without_mask[0, 0, -5:], expected_output_slice, rtol=1e-3)
@slow
def test_inference_no_head_long(self):
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
# 'Hello world! ' repeated 1000 times
input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.dtypes.int32)
attention_mask = tf.ones(shape_list(input_ids), dtype=tf.dtypes.int32)
global_attention_mask = tf.zeros(shape_list(input_ids), dtype=tf.dtypes.int32)
# Set global attention on a few random positions
global_attention_mask = tf.tensor_scatter_nd_update(
global_attention_mask, tf.constant([[0, 1], [0, 4], [0, 21]]), tf.constant([1, 1, 1])
)
output = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)[0]
expected_output_sum = tf.constant(74585.875)
expected_output_mean = tf.constant(0.024267)
# assert close
tf.debugging.assert_near(tf.reduce_sum(output), expected_output_sum, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_mean(output), expected_output_mean, rtol=1e-4)
@slow
def test_inference_masked_lm_long(self):
model = TFLongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
# 'Hello world! ' repeated 1000 times
input_ids = tf.convert_to_tensor([[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=tf.dtypes.int32)
loss, prediction_scores = model(input_ids, labels=input_ids)
expected_loss = tf.constant(0.0073798)
expected_prediction_scores_sum = tf.constant(-610476600.0)
expected_prediction_scores_mean = tf.constant(-3.03477)
# assert close
tf.debugging.assert_near(tf.reduce_mean(loss), expected_loss, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_sum(prediction_scores), expected_prediction_scores_sum, rtol=1e-4)
tf.debugging.assert_near(tf.reduce_mean(prediction_scores), expected_prediction_scores_mean, rtol=1e-4)
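A note on the @slow tests above: they are skipped by default in the transformers test suite and are typically enabled by setting the RUN_SLOW=1 environment variable when running pytest on this test file.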