* create model

* add integration

* save current state

* make integration tests pass

* add one more test

* add explanation to tests

* remove from bart

* add padding

* remove unnecessary test

* make all tests pass

* re-add cookie cutter tests

* finish PyTorch

* fix attention test

* Update tests/test_modeling_common.py

* revert change

* remove unused file

* add string to doc

* save intermediate

* make tf integration tests pass

* finish tf

* fix doc

* fix docs again

* add led to doctree

* add to auto tokenizer

* added tips for led

* make style

* apply jplus statements

* correct tf longformer

* apply lysandres suggestions

* apply sylvains suggestions

* Apply suggestions from code review
Patrick von Platen 2021-01-05 13:14:30 +01:00 committed by GitHub
parent 314cca2842
commit 189387e9b2
24 changed files with 6242 additions and 40 deletions


@ -240,6 +240,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ |
@ -371,6 +373,7 @@ TensorFlow and/or Flax.
model_doc/fsmt
model_doc/funnel
model_doc/layoutlm
model_doc/led
model_doc/longformer
model_doc/lxmert
model_doc/marian


@ -0,0 +1,150 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
LED
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The LED model was proposed in `Longformer: The Long-Document Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz
Beltagy, Matthew E. Peters, Arman Cohan.
The abstract from the paper is the following:
*Transformer-based models are unable to process long sequences due to their self-attention operation, which scales
quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention
mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or
longer. Longformer's attention mechanism is a drop-in replacement for the standard self-attention and combines a local
windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we
evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In
contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our
pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on
WikiHop and TriviaQA. We finally introduce the Longformer-Encoder-Decoder (LED), a Longformer variant for supporting
long document generative sequence-to-sequence tasks, and demonstrate its effectiveness on the arXiv summarization
dataset.*
Tips:
- :class:`~transformers.LEDForConditionalGeneration` is an extension of
:class:`~transformers.BartForConditionalGeneration` exchanging the traditional *self-attention* layer with
*Longformer*'s *chunked self-attention* layer. :class:`~transformers.LEDTokenizer` is an alias of
:class:`~transformers.BartTokenizer`.
- LED works very well on long-range *sequence-to-sequence* tasks where the ``input_ids`` largely exceed a length of
1024 tokens.
- LED pads the ``input_ids`` to be a multiple of ``config.attention_window`` if required, so a small speed-up is
  gained when :class:`~transformers.LEDTokenizer` is used with the ``pad_to_multiple_of`` argument.
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
:class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384 input tokens, it is necessary to enable *gradient checkpointing* by setting
  ``config.gradient_checkpointing = True`` (see the usage sketch after these tips).
- A notebook showing how to evaluate LED can be accessed `here
  <https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
- A notebook showing how to fine-tune LED can be accessed `here
  <https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing>`__.
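The snippet below is a minimal usage sketch tying the tips above together; the checkpoint name is the one used throughout this commit, while the generation settings (``num_beams``, ``max_length``) are illustrative only:

import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

article = "..."  # a long document to summarize
inputs = tokenizer(article, return_tensors="pt")

# put global attention on the first (<s>) token only, as advised for summarization
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
    num_beams=4,
    max_length=256,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))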
LEDConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDConfig
:members:
LEDTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary
LEDTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDTokenizerFast
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary
LED specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.models.led.modeling_led.LEDEncoderBaseModelOutput
:members:
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqModelOutput
:members:
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqLMOutput
:members:
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqSequenceClassifierOutput
:members:
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqQuestionAnsweringModelOutput
:members:
.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDEncoderBaseModelOutput
:members:
.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDSeq2SeqModelOutput
:members:
.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDSeq2SeqLMOutput
:members:
LEDModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDModel
:members: forward
LEDForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDForConditionalGeneration
:members: forward
LEDForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDForSequenceClassification
:members: forward
LEDForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LEDForQuestionAnswering
:members: forward
TFLEDModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLEDModel
:members: call
TFLEDForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLEDForConditionalGeneration
:members: call


@ -146,6 +146,7 @@ from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, F
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.herbert import HerbertTokenizer
from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.marian import MarianConfig
@ -255,6 +256,7 @@ if is_tokenizers_available():
from .models.gpt2 import GPT2TokenizerFast
from .models.herbert import HerbertTokenizerFast
from .models.layoutlm import LayoutLMTokenizerFast
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mbart import MBartTokenizerFast
@ -299,6 +301,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Modeling
if is_torch_available():
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
@ -507,6 +510,13 @@ if is_torch_available():
LayoutLMForTokenClassification,
LayoutLMModel,
)
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
)
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
LongformerForMaskedLM,
@ -695,6 +705,7 @@ else:
# TensorFlow
if is_tf_available():
from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
# Benchmarks
@ -831,6 +842,7 @@ if is_tf_available():
TFGPT2Model,
TFGPT2PreTrainedModel,
)
from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
from .models.longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,


@ -615,6 +615,7 @@ SLOW_TO_FAST_CONVERTERS = {
"HerbertTokenizer": HerbertConverter,
"LayoutLMTokenizer": BertConverter,
"LongformerTokenizer": RobertaConverter,
"LEDTokenizer": RobertaConverter,
"LxmertTokenizer": BertConverter,
"MBartTokenizer": MBartConverter,
"MPNetTokenizer": MPNetConverter,


@ -35,6 +35,7 @@ from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTCo
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..marian.configuration_marian import MarianConfig
@ -66,6 +67,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
# Add archive maps here
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@ -105,6 +107,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("led", LEDConfig),
("retribert", RetriBertConfig),
("mt5", MT5Config),
("t5", T5Config),
@ -150,6 +153,7 @@ CONFIG_MAPPING = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("led", "LED"),
("retribert", "RetriBERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),


@ -101,6 +101,12 @@ from ..funnel.modeling_funnel import (
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..layoutlm.modeling_layoutlm import LayoutLMForMaskedLM, LayoutLMForTokenClassification, LayoutLMModel
from ..led.modeling_led import (
LEDForConditionalGeneration,
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
)
from ..longformer.modeling_longformer import (
LongformerForMaskedLM,
LongformerForMultipleChoice,
@ -221,6 +227,7 @@ from .configuration_auto import (
FunnelConfig,
GPT2Config,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
MarianConfig,
@ -252,6 +259,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(LEDConfig, LEDModel),
(RetriBertConfig, RetriBertModel),
(MT5Config, MT5Model),
(T5Config, T5Model),
@ -327,6 +335,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
# Model with LM heads mapping
(LEDConfig, LEDForConditionalGeneration),
(LayoutLMConfig, LayoutLMForMaskedLM),
(T5Config, T5ForConditionalGeneration),
(DistilBertConfig, DistilBertForMaskedLM),
@ -407,6 +416,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(LEDConfig, LEDForConditionalGeneration),
(MT5Config, MT5ForConditionalGeneration),
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
@ -424,6 +434,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
(LEDConfig, LEDForSequenceClassification),
(DistilBertConfig, DistilBertForSequenceClassification),
(AlbertConfig, AlbertForSequenceClassification),
(CamembertConfig, CamembertForSequenceClassification),
@ -453,6 +464,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
# Model for Question Answering mapping
(LEDConfig, LEDForQuestionAnswering),
(DistilBertConfig, DistilBertForQuestionAnswering),
(AlbertConfig, AlbertForQuestionAnswering),
(CamembertConfig, CamembertForQuestionAnswering),


@ -90,6 +90,7 @@ from ..funnel.modeling_tf_funnel import (
TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
@ -174,6 +175,7 @@ from .configuration_auto import (
FlaubertConfig,
FunnelConfig,
GPT2Config,
LEDConfig,
LongformerConfig,
LxmertConfig,
MarianConfig,
@ -199,6 +201,7 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(LEDConfig, TFLEDModel),
(LxmertConfig, TFLxmertModel),
(MT5Config, TFMT5Model),
(T5Config, TFT5Model),
@ -254,6 +257,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
# Model with LM heads mapping
(LEDConfig, TFLEDForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
@ -317,6 +321,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(LEDConfig, TFLEDForConditionalGeneration),
(MT5Config, TFMT5ForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),


@ -36,6 +36,7 @@ from ..funnel.tokenization_funnel import FunnelTokenizer
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
from ..herbert.tokenization_herbert import HerbertTokenizer
from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer
from ..led.tokenization_led import LEDTokenizer
from ..longformer.tokenization_longformer import LongformerTokenizer
from ..lxmert.tokenization_lxmert import LxmertTokenizer
from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer
@ -69,6 +70,7 @@ from .configuration_auto import (
FunnelConfig,
GPT2Config,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
MarianConfig,
@ -137,6 +139,7 @@ if is_tokenizers_available():
from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast
from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast
from ..led.tokenization_led_fast import LEDTokenizerFast
from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
@ -226,6 +229,7 @@ TOKENIZER_MAPPING = OrderedDict(
(ProphetNetConfig, (ProphetNetTokenizer, None)),
(MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
]
)
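To make the effect of the new mappings concrete, here is a short sketch of the auto classes resolving an LED checkpoint (same checkpoint as referenced elsewhere in this commit; the fast tokenizer is only picked when the ``tokenizers`` library is installed):

from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

config = AutoConfig.from_pretrained("allenai/led-base-16384")
tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384")

# resolved via CONFIG_MAPPING, TOKENIZER_MAPPING and MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
print(type(config).__name__)     # LEDConfig
print(type(tokenizer).__name__)  # LEDTokenizerFast (or LEDTokenizer without `tokenizers`)
print(type(model).__name__)      # LEDForConditionalGeneration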


@ -0,0 +1,38 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from .tokenization_led import LEDTokenizer
if is_tokenizers_available():
from .tokenization_led_fast import LEDTokenizerFast
if is_torch_available():
from .modeling_led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
LEDForConditionalGeneration,
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
LEDPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel


@ -0,0 +1,179 @@
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LED model configuration """
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
LED_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/config.json",
# See all LED models at https://huggingface.co/models?filter=led
}
class LEDConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.LEDModel`. It is used to
instantiate an LED model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the LED `allenai/led-base-16384
<https://huggingface.co/allenai/led-base-16384>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 50265):
Vocabulary size of the LED model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.LEDModel` or :class:`~transformers.TFLEDModel`.
d_model (:obj:`int`, `optional`, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (:obj:`int`, `optional`, defaults to 12):
Number of encoder layers.
decoder_layers (:obj:`int`, `optional`, defaults to 12):
Number of decoder layers.
encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the classifier.
max_encoder_position_embeddings (:obj:`int`, `optional`, defaults to 16384):
The maximum sequence length that the encoder might ever be used with.
max_decoder_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
The maximum sequence length that the decoder might ever be used with.
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of a slower backward pass.
Example::
>>> from transformers import LEDModel, LEDConfig
>>> # Initializing a LED allenai/led-base-16384 style configuration
>>> configuration = LEDConfig()
>>> # Initializing a model from the allenai/led-base-16384 style configuration
>>> model = LEDModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "led"
def __init__(
self,
vocab_size=50265,
max_encoder_position_embeddings=16384,
max_decoder_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
decoder_start_token_id=2,
classifier_dropout=0.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
gradient_checkpointing=False,
attention_window: Union[List[int], int] = 512,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.max_encoder_position_embeddings = max_encoder_position_embeddings
self.max_decoder_position_embeddings = max_decoder_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.attention_window = attention_window
self.gradient_checkpointing = gradient_checkpointing
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self) -> int:
return self.d_model
@property
def attention_probs_dropout_prob(self) -> float:
return self.attention_dropout
@property
def initializer_range(self) -> float:
return self.init_std
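As a small sketch of how these arguments compose (values are illustrative, not defaults; ``attention_window`` may also be given as a list with one even window size per encoder layer, per the ``Union[List[int], int]`` annotation):

from transformers import LEDConfig, LEDModel

config = LEDConfig(
    d_model=256,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=512,
    decoder_ffn_dim=512,
    attention_window=256,
    gradient_checkpointing=True,  # trade compute for memory when training on long inputs
)
model = LEDModel(config)  # randomly initialized

# the BERT-style aliases defined by the properties above
print(config.hidden_size, config.num_attention_heads)  # 256 4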

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -0,0 +1,51 @@
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
logger = logging.get_logger(__name__)
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
},
"merges_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
},
"tokenizer_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"allenai/led-base-16384": 16384,
}
class LEDTokenizer(BartTokenizer):
"""
Construct a LED tokenizer.
:class:`~transformers.LEDTokenizer` is identical to :class:`~transformers.BartTokenizer` and runs end-to-end
tokenization based on byte-level Byte-Pair-Encoding.
Refer to superclass :class:`~transformers.BartTokenizer` for usage examples and documentation concerning
parameters.
"""
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
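Since :class:`~transformers.LEDTokenizer` is a plain alias of :class:`~transformers.BartTokenizer`, usage is unchanged; a short sketch of the ``pad_to_multiple_of`` trick mentioned in the tips (the multiple of 1024 is illustrative, any multiple of the configured attention window works):

from transformers import LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
batch = tokenizer(
    ["first long document ...", "second long document ..."],
    padding=True,
    pad_to_multiple_of=1024,  # avoids re-padding inside the model
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # (2, some multiple of 1024)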


@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from ...utils import logging
from ..bart.tokenization_bart_fast import BartTokenizerFast
from .tokenization_led import LEDTokenizer
logger = logging.get_logger(__name__)
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
},
"merges_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
},
"tokenizer_file": {
"allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"allenai/led-base-16384": 16384,
}
class LEDTokenizerFast(BartTokenizerFast):
r"""
Construct a "fast" LED tokenizer (backed by HuggingFace's `tokenizers` library).
:class:`~transformers.LEDTokenizerFast` is identical to :class:`~transformers.BartTokenizerFast` and runs
end-to-end tokenization based on byte-level Byte-Pair-Encoding.
Refer to superclass :class:`~transformers.BartTokenizerFast` for usage examples and documentation concerning
parameters.
"""
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
slow_tokenizer_class = LEDTokenizer


@ -549,13 +549,19 @@ class LongformerSelfAttention(nn.Module):
self.one_sided_attn_window_size = attention_window // 2
def forward(
self, hidden_states, attention_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
`attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer.
:class:`LongformerSelfAttention` expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
`attention_window` happens in :meth:`LongformerModel.forward` to avoid redoing the padding on each layer.
The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to -ve: no attention
The `attention_mask` is changed in :meth:`BertModel.forward` from 0, 1, 2 to -ve: no attention
0: local attention
+ve: global attention
@ -631,17 +637,17 @@ class LongformerSelfAttention(nn.Module):
# free memory
del global_key_attn_scores
local_attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
local_attn_probs = local_attn_probs_fp32.type_as(attn_scores)
# free memory
del local_attn_probs_fp32
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
local_attn_probs = torch.masked_fill(local_attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = attn_probs.type_as(attn_scores)
# free memory
del attn_scores
# apply dropout
local_attn_probs = F.dropout(local_attn_probs, p=self.dropout, training=self.training)
attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)
value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
@ -650,7 +656,7 @@ class LongformerSelfAttention(nn.Module):
# compute sum of global and local attn
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=local_attn_probs,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
@ -658,7 +664,7 @@ class LongformerSelfAttention(nn.Module):
else:
# compute local attn only
attn_output = self._sliding_chunks_matmul_attn_probs_value(
local_attn_probs, value_vectors, self.one_sided_attn_window_size
attn_probs, value_vectors, self.one_sided_attn_window_size
)
assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
@ -688,11 +694,14 @@ class LongformerSelfAttention(nn.Module):
# The attention weights for tokens with global attention are
# just filler values, they were never used to compute the output.
# Fill with 0 now, the correct values are in 'global_attn_probs'.
local_attn_probs[is_index_global_attn_nonzero] = 0
attn_probs[is_index_global_attn_nonzero] = 0
outputs = (attn_output.transpose(0, 1), local_attn_probs)
outputs = (attn_output.transpose(0, 1),)
return outputs + (global_attn_probs,) if is_global_attn else outputs
if output_attentions:
outputs += (attn_probs,)
return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs
@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
@ -711,6 +720,7 @@ class LongformerSelfAttention(nn.Module):
shift every row 1 step right, converting columns into diagonals.
Example::
chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
@ -728,13 +738,13 @@ class LongformerSelfAttention(nn.Module):
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, -1
) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap
) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
chunked_hidden_states = chunked_hidden_states[
:, :, :-window_overlap
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
] # total_num_heads x num_chunks x window_overlap*window_overlap
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states
@ -788,18 +798,18 @@ class LongformerSelfAttention(nn.Module):
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
query = self._chunk(query, window_overlap)
key = self._chunk(key, window_overlap)
# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap
chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply
diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply
# convert diagonals into columns
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, padding=(0, 0, 0, 1)
diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
)
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
@ -1095,7 +1105,13 @@ class LongformerAttention(nn.Module):
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self, hidden_states, attention_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
@ -1103,6 +1119,7 @@ class LongformerAttention(nn.Module):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:]
@ -1150,7 +1167,13 @@ class LongformerLayer(nn.Module):
self.seq_len_dim = 1
def forward(
self, hidden_states, attention_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
self_attn_outputs = self.attention(
hidden_states,
@ -1158,6 +1181,7 @@ class LongformerLayer(nn.Module):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:]
@ -1205,7 +1229,7 @@ class LongformerEncoder(nn.Module):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, is_global_attn)
return module(*inputs, is_global_attn, output_attentions)
return custom_forward
@ -1223,6 +1247,7 @@ class LongformerEncoder(nn.Module):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]


@ -796,7 +796,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# normalize query
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
@ -945,7 +945,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply
# convert diagonals into columns
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
# allocate space for the overall attention matrix where the chunks are combined. The last dimension
@ -1093,7 +1093,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)
# pad seq_len with w at the beginning of the sequence and another window overlap at the end
paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
padded_value = tf.pad(value, paddings, constant_values=-1)
# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
@ -1141,6 +1141,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
shift every row 1 step right, converting columns into diagonals.
Example::
chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
@ -1153,7 +1154,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
chunked_hidden_states = tf.pad(
chunked_hidden_states, paddings
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
@ -1349,7 +1350,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
global_value_vectors = self.value_global(hidden_states)
# normalize
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
@ -1820,10 +1821,10 @@ class TFLongformerPreTrainedModel(TFPreTrainedModel):
@property
def dummy_inputs(self):
input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
input_ids = tf.convert_to_tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
# make sure global layers are initialized
attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
global_attention_mask = tf.constant([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])
attention_mask = tf.convert_to_tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
global_attention_mask = tf.convert_to_tensor([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
@ -2371,9 +2372,9 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
@property
def dummy_inputs(self):
input_ids = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)
input_ids = tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)
# make sure global layers are initialized
global_attention_mask = tf.constant([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
global_attention_mask = tf.convert_to_tensor([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
return {"input_ids": input_ids, "global_attention_mask": global_attention_mask}
@add_start_docstrings_to_model_forward(


@ -1179,6 +1179,45 @@ class LayoutLMModel:
requires_pytorch(self)
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
class LEDForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class LEDForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class LEDForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class LEDModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


@ -827,6 +827,33 @@ class TFGPT2PreTrainedModel:
requires_tf(self)
class TFLEDForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLEDModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLEDPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


@ -128,6 +128,15 @@ class LayoutLMTokenizerFast:
requires_tokenizers(self)
class LEDTokenizerFast:
def __init__(self, *args, **kwargs):
requires_tokenizers(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tokenizers(self)
class LongformerTokenizerFast:
def __init__(self, *args, **kwargs):
requires_tokenizers(self)

tests/test_modeling_led.py (new file, 511 lines)

File diff suppressed because one or more lines are too long


@ -488,13 +488,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, _ = layer(
output_hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
)[0]
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
self.assertTrue(
@ -526,13 +526,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, _, _ = layer(
output_hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@ -583,6 +583,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=True,
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))


@ -0,0 +1,348 @@
# coding=utf-8
# Copyright Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import LEDConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFLEDForConditionalGeneration, TFLEDModel
@require_tf
class TFLEDModelTester:
config_cls = LEDConfig
config_updates = {}
hidden_act = "gelu"
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
attention_window=4,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.attention_window = attention_window
# `ModelTesterMixin.test_attention_outputs` is expecting attention tensors to be of size
# [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
# because its local attention only attends to `self.attention_window` and one before and one after
self.key_length = self.attention_window + 1
# because of padding, `encoder_seq_length` is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
self.encoder_seq_length = (
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
)
def prepare_config_and_inputs_for_common(self):
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = self.config_cls(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_ids=[2],
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.pad_token_id,
attention_window=self.attention_window,
**self.config_updates,
)
inputs_dict = prepare_led_inputs_dict(config, input_ids, decoder_input_ids)
global_attention_mask = tf.concat(
[tf.zeros_like(input_ids)[:, :-1], tf.ones_like(input_ids)[:, -1:]],
axis=-1,
)
inputs_dict["global_attention_mask"] = global_attention_mask
return config, inputs_dict
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = TFLEDModel(config=config).get_decoder()
input_ids = inputs_dict["input_ids"]
input_ids = input_ids[:1, :]
attention_mask = inputs_dict["attention_mask"][:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
past_key_values = past_key_values[1]
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
# append to next input_ids and attention mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def prepare_led_inputs_dict(
config,
input_ids,
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
if decoder_attention_mask is None:
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
@require_tf
class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFLEDForConditionalGeneration, TFLEDModel) if is_tf_available() else ()
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)
self.config_tester = ConfigTester(self, config_class=LEDConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_layer_with_bias()
assert x is None
name = model.get_prefix_bias_name()
assert name is None
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["global_attention_mask"] = tf.zeros_like(inputs_dict["attention_mask"])
num_global_attn_indices = 2
inputs_dict["global_attention_mask"] = tf.where(
tf.range(self.model_tester.seq_length)[None, :] < num_global_attn_indices,
1,
inputs_dict["global_attention_mask"],
)
config.return_dict = True
seq_length = self.model_tester.seq_length
encoder_seq_length = self.model_tester.encoder_seq_length
def check_decoder_attentions_output(outputs):
decoder_attentions = outputs.decoder_attentions
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
def check_encoder_attentions_output(outputs):
attentions = [t.numpy() for t in outputs.encoder_attentions]
global_attentions = [t.numpy() for t in outputs.encoder_global_attentions]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
)
self.assertListEqual(
list(global_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, num_global_attn_indices],
)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["use_cache"] = False
config.output_hidden_states = False
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
out_len = len(outputs)
self.assertEqual(config.output_hidden_states, False)
check_encoder_attentions_output(outputs)
if self.is_encoder_decoder:
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(config.output_hidden_states, False)
check_decoder_attentions_output(outputs)
# Check that output attentions can also be changed via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(config.output_hidden_states, False)
check_encoder_attentions_output(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_hidden_states, True)
check_encoder_attentions_output(outputs)
@slow
def test_saved_model_with_attentions_output(self):
# longformer has special attentions which are not
# compatible in graph mode
pass
@slow
def test_saved_model_with_hidden_states_output(self):
# TODO(JPLU, PVP) this test should pass!!! PVP:
# IMO there is a problem with the signature check.
# Test passes for TFLEDModel, but not for TFLEDForConditionalGeneration
# IMO the reason is that the tensor variable name cannot be changed
# from decoder_input_ids -> input_ids, which poses a big restriction
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors are not close, or a and b aren't both tensors, raise a nice assertion error."""
if a is None and b is None:
return True
try:
if tf.debugging.assert_near(a, b, atol=atol):
return True
raise
except Exception:
msg = "{} != {}".format(a, b)
if prefix:
msg = prefix + ": " + msg
raise AssertionError(msg)
def _long_tensor(tok_lst):
return tf.constant(tok_lst, dtype=tf.int32)
TOLERANCE = 1e-4
@slow
@require_tf
class TFLEDModelIntegrationTest(unittest.TestCase):
def test_inference_no_head(self):
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384").led
# change to intended input here
input_ids = _long_tensor([512 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
decoder_input_ids = _long_tensor([128 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
inputs_dict = prepare_led_inputs_dict(model.config, input_ids, decoder_input_ids)
output = model(**inputs_dict)[0]
expected_shape = (1, 1024, 768)
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = tf.convert_to_tensor(
[[2.3050, 2.8279, 0.6531], [-1.8457, -0.1455, -3.5661], [-1.0186, 0.4586, -2.2043]],
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
def test_inference_with_head(self):
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
# change to intended input here
input_ids = _long_tensor([512 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
decoder_input_ids = _long_tensor([128 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
inputs_dict = prepare_led_inputs_dict(model.config, input_ids, decoder_input_ids)
output = model(**inputs_dict)[0]
expected_shape = (1, 1024, model.config.vocab_size)
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = tf.convert_to_tensor(
[[33.6507, 6.4572, 16.8089], [5.8739, -2.4238, 11.2902], [-3.2139, -4.3149, 4.2783]],
)
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
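For completeness, a minimal TensorFlow generation sketch mirroring the PyTorch usage (decoding settings are illustrative only):

from transformers import LEDTokenizer, TFLEDForConditionalGeneration

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("a long article to summarize ...", return_tensors="tf")
summary_ids = model.generate(inputs["input_ids"], num_beams=2, max_length=64)
print(tokenizer.batch_decode(summary_ids.numpy(), skip_special_tokens=True))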


@ -339,6 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
@slow
def test_saved_model_with_attentions_output(self):
# longformer has special attentions which are not
# compatible in graph mode
pass


@ -30,6 +30,8 @@ PATH_TO_DOC = "docs/source"
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
# models to ignore for not tested
"LEDEncoder", # Building part of bigger (tested) model.
"LEDDecoder", # Building part of bigger (tested) model.
"BartDecoder", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder.
@ -64,6 +66,8 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
# models to ignore for model xxx mapping
"LEDEncoder",
"LEDDecoder",
"BartDecoder",
"BartEncoder",
"DPRContextEncoder",