Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

LED (#9278)

* create model
* add integration
* save current state
* make integration tests pass
* add one more test
* add explanation to tests
* remove from bart
* add padding
* remove unnecessary test
* make all tests pass
* re-add cookie cutter tests
* finish PyTorch
* fix attention test
* Update tests/test_modeling_common.py
* revert change
* remove unused file
* add string to doc
* save intermediate
* make tf integration tests pass
* finish tf
* fix doc
* fix docs again
* add led to doctree
* add to auto tokenizer
* added tips for led
* make style
* apply jplus statements
* correct tf longformer
* apply lysandres suggestions
* apply sylvains suggestions
* Apply suggestions from code review

This commit is contained in:
parent 314cca2842
commit 189387e9b2

docs/source/index.rst
@@ -240,6 +240,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ |

@@ -371,6 +373,7 @@ TensorFlow and/or Flax.
    model_doc/fsmt
    model_doc/funnel
    model_doc/layoutlm
    model_doc/led
    model_doc/longformer
    model_doc/lxmert
    model_doc/marian

docs/source/model_doc/led.rst (new file, 150 lines)
@@ -0,0 +1,150 @@
..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

LED
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The LED model was proposed in `Longformer: The Long-Document Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz
Beltagy, Matthew E. Peters and Arman Cohan.

The abstract from the paper is the following:

*Transformer-based models are unable to process long sequences due to their self-attention operation, which scales
quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention
mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or
longer. Longformer's attention mechanism is a drop-in replacement for the standard self-attention and combines a local
windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we
evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In
contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our
pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on
WikiHop and TriviaQA. We finally introduce the Longformer-Encoder-Decoder (LED), a Longformer variant for supporting
long document generative sequence-to-sequence tasks, and demonstrate its effectiveness on the arXiv summarization
dataset.*

Tips:

- :class:`~transformers.LEDForConditionalGeneration` is an extension of
  :class:`~transformers.BartForConditionalGeneration` that exchanges the traditional *self-attention* layer with
  *Longformer*'s *chunked self-attention* layer. :class:`~transformers.LEDTokenizer` is an alias of
  :class:`~transformers.BartTokenizer`.
- LED works very well on long-range *sequence-to-sequence* tasks where the ``input_ids`` largely exceed a length of
  1024 tokens.
- LED pads the ``input_ids`` to be a multiple of ``config.attention_window`` if required. Therefore, a small speed-up
  is gained when :class:`~transformers.LEDTokenizer` is used with the ``pad_to_multiple_of`` argument.
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
  :class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
  ``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question (see the
  usage sketch after this list).
- To fine-tune LED on all 16384 input positions, it is necessary to enable *gradient checkpointing* by setting
  ``config.gradient_checkpointing = True``.
- A notebook showing how to evaluate LED can be accessed `here
  <https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
- A notebook showing how to fine-tune LED can be accessed `here
  <https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing>`__.
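
A minimal usage sketch for summarization that ties the tips above together. It assumes the ``allenai/led-base-16384``
checkpoint released with this pull request, a user-provided ``long_document`` string, and the default attention window
of 512 tokens::

    >>> import torch
    >>> from transformers import LEDTokenizer, LEDForConditionalGeneration

    >>> tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

    >>> # pad to a multiple of the attention window for a small speed-up (512 assumed here)
    >>> inputs = tokenizer(long_document, return_tensors="pt", padding=True, pad_to_multiple_of=512)

    >>> # for summarization, put global attention only on the first <s> token
    >>> global_attention_mask = torch.zeros_like(inputs.input_ids)
    >>> global_attention_mask[:, 0] = 1

    >>> summary_ids = model.generate(
    ...     inputs.input_ids,
    ...     attention_mask=inputs.attention_mask,
    ...     global_attention_mask=global_attention_mask,
    ...     num_beams=4,
    ...     max_length=256,
    ... )
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])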

LEDConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDConfig
    :members:


LEDTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDTokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


LEDTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDTokenizerFast
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


LED specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.models.led.modeling_led.LEDEncoderBaseModelOutput
    :members:

.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqModelOutput
    :members:

.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqLMOutput
    :members:

.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqSequenceClassifierOutput
    :members:

.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqQuestionAnsweringModelOutput
    :members:

.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDEncoderBaseModelOutput
    :members:

.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDSeq2SeqModelOutput
    :members:

.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDSeq2SeqLMOutput
    :members:


LEDModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDModel
    :members: forward


LEDForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDForConditionalGeneration
    :members: forward


LEDForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDForSequenceClassification
    :members: forward


LEDForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LEDForQuestionAnswering
    :members: forward


TFLEDModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFLEDModel
    :members: call


TFLEDForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFLEDForConditionalGeneration
    :members: call

src/transformers/__init__.py
@@ -146,6 +146,7 @@ from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, F
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.herbert import HerbertTokenizer
from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.marian import MarianConfig
@@ -255,6 +256,7 @@ if is_tokenizers_available():
    from .models.gpt2 import GPT2TokenizerFast
    from .models.herbert import HerbertTokenizerFast
    from .models.layoutlm import LayoutLMTokenizerFast
    from .models.led import LEDTokenizerFast
    from .models.longformer import LongformerTokenizerFast
    from .models.lxmert import LxmertTokenizerFast
    from .models.mbart import MBartTokenizerFast
@@ -299,6 +301,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Modeling
if is_torch_available():

    # Benchmarks
    from .benchmark.benchmark import PyTorchBenchmark
    from .benchmark.benchmark_args import PyTorchBenchmarkArguments
@@ -507,6 +510,13 @@ if is_torch_available():
        LayoutLMForTokenClassification,
        LayoutLMModel,
    )
    from .models.led import (
        LED_PRETRAINED_MODEL_ARCHIVE_LIST,
        LEDForConditionalGeneration,
        LEDForQuestionAnswering,
        LEDForSequenceClassification,
        LEDModel,
    )
    from .models.longformer import (
        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        LongformerForMaskedLM,
@@ -695,6 +705,7 @@ else:

# TensorFlow
if is_tf_available():

    from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments

    # Benchmarks
@@ -831,6 +842,7 @@ if is_tf_available():
        TFGPT2Model,
        TFGPT2PreTrainedModel,
    )
    from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
    from .models.longformer import (
        TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFLongformerForMaskedLM,

src/transformers/convert_slow_tokenizer.py
@@ -615,6 +615,7 @@ SLOW_TO_FAST_CONVERTERS = {
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MBartTokenizer": MBartConverter,
    "MPNetTokenizer": MPNetConverter,

src/transformers/models/auto/configuration_auto.py
@@ -35,6 +35,7 @@ from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTCo
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..marian.configuration_marian import MarianConfig
@@ -66,6 +67,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        # Add archive maps here
        LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -105,6 +107,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
    [
        # Add configs here
        ("led", LEDConfig),
        ("retribert", RetriBertConfig),
        ("mt5", MT5Config),
        ("t5", T5Config),
@@ -150,6 +153,7 @@ CONFIG_MAPPING = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("led", "LED"),
        ("retribert", "RetriBERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),

src/transformers/models/auto/modeling_auto.py
@@ -101,6 +101,12 @@ from ..funnel.modeling_funnel import (
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..layoutlm.modeling_layoutlm import LayoutLMForMaskedLM, LayoutLMForTokenClassification, LayoutLMModel
from ..led.modeling_led import (
    LEDForConditionalGeneration,
    LEDForQuestionAnswering,
    LEDForSequenceClassification,
    LEDModel,
)
from ..longformer.modeling_longformer import (
    LongformerForMaskedLM,
    LongformerForMultipleChoice,
@@ -221,6 +227,7 @@ from .configuration_auto import (
    FunnelConfig,
    GPT2Config,
    LayoutLMConfig,
    LEDConfig,
    LongformerConfig,
    LxmertConfig,
    MarianConfig,
@@ -252,6 +259,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING = OrderedDict(
    [
        # Base model mapping
        (LEDConfig, LEDModel),
        (RetriBertConfig, RetriBertModel),
        (MT5Config, MT5Model),
        (T5Config, T5Model),
@@ -327,6 +335,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        # Model with LM heads mapping
        (LEDConfig, LEDForConditionalGeneration),
        (LayoutLMConfig, LayoutLMForMaskedLM),
        (T5Config, T5ForConditionalGeneration),
        (DistilBertConfig, DistilBertForMaskedLM),
@@ -407,6 +416,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (LEDConfig, LEDForConditionalGeneration),
        (MT5Config, MT5ForConditionalGeneration),
        (T5Config, T5ForConditionalGeneration),
        (PegasusConfig, PegasusForConditionalGeneration),
@@ -424,6 +434,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Sequence Classification mapping
        (LEDConfig, LEDForSequenceClassification),
        (DistilBertConfig, DistilBertForSequenceClassification),
        (AlbertConfig, AlbertForSequenceClassification),
        (CamembertConfig, CamembertForSequenceClassification),
@@ -453,6 +464,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        # Model for Question Answering mapping
        (LEDConfig, LEDForQuestionAnswering),
        (DistilBertConfig, DistilBertForQuestionAnswering),
        (AlbertConfig, AlbertForQuestionAnswering),
        (CamembertConfig, CamembertForQuestionAnswering),

src/transformers/models/auto/modeling_tf_auto.py
@@ -90,6 +90,7 @@ from ..funnel.modeling_tf_funnel import (
    TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
from ..longformer.modeling_tf_longformer import (
    TFLongformerForMaskedLM,
    TFLongformerForMultipleChoice,
@@ -174,6 +175,7 @@ from .configuration_auto import (
    FlaubertConfig,
    FunnelConfig,
    GPT2Config,
    LEDConfig,
    LongformerConfig,
    LxmertConfig,
    MarianConfig,
@@ -199,6 +201,7 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING = OrderedDict(
    [
        # Base model mapping
        (LEDConfig, TFLEDModel),
        (LxmertConfig, TFLxmertModel),
        (MT5Config, TFMT5Model),
        (T5Config, TFT5Model),
@@ -254,6 +257,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        # Model with LM heads mapping
        (LEDConfig, TFLEDForConditionalGeneration),
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
@@ -317,6 +321,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (LEDConfig, TFLEDForConditionalGeneration),
        (MT5Config, TFMT5ForConditionalGeneration),
        (T5Config, TFT5ForConditionalGeneration),
        (MarianConfig, TFMarianMTModel),

src/transformers/models/auto/tokenization_auto.py
@@ -36,6 +36,7 @@ from ..funnel.tokenization_funnel import FunnelTokenizer
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
from ..herbert.tokenization_herbert import HerbertTokenizer
from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer
from ..led.tokenization_led import LEDTokenizer
from ..longformer.tokenization_longformer import LongformerTokenizer
from ..lxmert.tokenization_lxmert import LxmertTokenizer
from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer
@@ -69,6 +70,7 @@ from .configuration_auto import (
    FunnelConfig,
    GPT2Config,
    LayoutLMConfig,
    LEDConfig,
    LongformerConfig,
    LxmertConfig,
    MarianConfig,
@@ -137,6 +139,7 @@ if is_tokenizers_available():
    from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
    from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast
    from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast
    from ..led.tokenization_led_fast import LEDTokenizerFast
    from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast
    from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
    from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
@@ -226,6 +229,7 @@ TOKENIZER_MAPPING = OrderedDict(
        (ProphetNetConfig, (ProphetNetTokenizer, None)),
        (MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
        (TapasConfig, (TapasTokenizer, None)),
        (LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
    ]
)
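
With the mappings added above, the LED classes are also reachable through the auto classes. A minimal sketch, assuming
the "allenai/led-base-16384" checkpoint and that both PyTorch and the tokenizers library are installed:

    from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM

    config = AutoConfig.from_pretrained("allenai/led-base-16384")            # resolves to LEDConfig
    tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")      # resolves to LEDTokenizerFast
    model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384")  # resolves to LEDForConditionalGeneration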

src/transformers/models/led/__init__.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...file_utils import is_tf_available, is_tokenizers_available, is_torch_available
from .configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from .tokenization_led import LEDTokenizer


if is_tokenizers_available():
    from .tokenization_led_fast import LEDTokenizerFast

if is_torch_available():
    from .modeling_led import (
        LED_PRETRAINED_MODEL_ARCHIVE_LIST,
        LEDForConditionalGeneration,
        LEDForQuestionAnswering,
        LEDForSequenceClassification,
        LEDModel,
        LEDPreTrainedModel,
    )


if is_tf_available():
    from .modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel
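
Because of the conditional imports above, the public LED names always exist under transformers, but the heavy backends
are only wired in when available; otherwise the dummy placeholder classes added later in this diff are exported and
raise an error on use. A minimal sketch of the same availability check in user code (illustration only):

    from transformers import is_torch_available

    if is_torch_available():
        from transformers import LEDForConditionalGeneration, LEDModel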

src/transformers/models/led/configuration_led.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LED model configuration """

from typing import List, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

LED_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/config.json",
    # See all LED models at https://huggingface.co/models?filter=led
}


class LEDConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.LEDModel`. It is used to
    instantiate an LED model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the LED `allenai/led-base-16384
    <https://huggingface.co/allenai/led-base-16384>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 50265):
            Vocabulary size of the LED model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.LEDModel` or :class:`~transformers.TFLEDModel`.
        d_model (:obj:`int`, `optional`, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (:obj:`int`, `optional`, defaults to 12):
            Number of encoder layers.
        decoder_layers (:obj:`int`, `optional`, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
        dropout (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the classifier.
        max_encoder_position_embeddings (:obj:`int`, `optional`, defaults to 16384):
            The maximum sequence length that the encoder might ever be used with.
        max_decoder_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
            The maximum sequence length that the decoder might ever be used with.
        init_std (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the encoder. See the `LayerDrop paper
            <https://arxiv.org/abs/1909.11556>`__ for more details.
        decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the decoder. See the `LayerDrop paper
            <https://arxiv.org/abs/1909.11556>`__ for more details.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, use gradient checkpointing to save memory at the expense of a slower backward pass.

    Example::

        >>> from transformers import LEDModel, LEDConfig

        >>> # Initializing a LED allenai/led-base-16384 style configuration
        >>> configuration = LEDConfig()

        >>> # Initializing a model from the allenai/led-base-16384 style configuration
        >>> model = LEDModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    model_type = "led"

    def __init__(
        self,
        vocab_size=50265,
        max_encoder_position_embeddings=16384,
        max_decoder_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=2,
        classifier_dropout=0.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        gradient_checkpointing=False,
        attention_window: Union[List[int], int] = 512,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.max_encoder_position_embeddings = max_encoder_position_embeddings
        self.max_decoder_position_embeddings = max_decoder_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.attention_window = attention_window
        self.gradient_checkpointing = gradient_checkpointing

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads

    @property
    def hidden_size(self) -> int:
        return self.d_model

    @property
    def attention_probs_dropout_prob(self) -> float:
        return self.attention_dropout

    @property
    def initializer_range(self) -> float:
        return self.init_std
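
A short sketch of how the BERT-style property aliases defined above map onto the BART-style attribute names (a
hypothetical configuration, shown for illustration only):

    from transformers import LEDConfig

    config = LEDConfig(encoder_attention_heads=8, d_model=512, attention_window=[256] * 12)
    assert config.num_attention_heads == 8            # alias for encoder_attention_heads
    assert config.hidden_size == 512                  # alias for d_model
    assert config.num_hidden_layers == config.encoder_layers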

src/transformers/models/led/modeling_led.py (new executable file, 2476 lines)
File diff suppressed because it is too large.

src/transformers/models/led/modeling_tf_led.py (new file, 2247 lines)
File diff suppressed because it is too large.

src/transformers/models/led/tokenization_led.py (new file, 51 lines)
@@ -0,0 +1,51 @@
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer


logger = logging.get_logger(__name__)

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
    },
    "merges_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/led-base-16384": 16384,
}


class LEDTokenizer(BartTokenizer):
    """
    Construct a LED tokenizer.

    :class:`~transformers.LEDTokenizer` is identical to :class:`~transformers.BartTokenizer` and runs end-to-end
    tokenization using byte-level Byte-Pair Encoding.

    Refer to superclass :class:`~transformers.BartTokenizer` for usage examples and documentation concerning
    parameters.
    """

    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
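
Because the class only overrides the pretrained maps, it behaves exactly like the BART tokenizer but advertises the
16384-token model size. A minimal sketch, assuming the "allenai/led-base-16384" checkpoint can be downloaded:

    from transformers import LEDTokenizer

    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    print(tokenizer.model_max_length)  # 16384, taken from PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES above

    # pad to a multiple of the attention window, as suggested in docs/source/model_doc/led.rst
    batch = tokenizer(["a short example"], padding=True, pad_to_multiple_of=512)
    print(len(batch.input_ids[0]))  # 512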

src/transformers/models/led/tokenization_led_fast.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LED."""
from ...utils import logging
from ..bart.tokenization_bart_fast import BartTokenizerFast
from .tokenization_led import LEDTokenizer


logger = logging.get_logger(__name__)

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/vocab.json",
    },
    "merges_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "allenai/led-base-16384": "https://huggingface.co/allenai/led-base-16384/resolve/main/tokenizer.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "allenai/led-base-16384": 16384,
}


class LEDTokenizerFast(BartTokenizerFast):
    r"""
    Construct a "fast" LED tokenizer (backed by HuggingFace's `tokenizers` library).

    :class:`~transformers.LEDTokenizerFast` is identical to :class:`~transformers.BartTokenizerFast` and runs
    end-to-end tokenization using byte-level Byte-Pair Encoding.

    Refer to superclass :class:`~transformers.BartTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = LEDTokenizer

src/transformers/models/longformer/modeling_longformer.py
@@ -549,13 +549,19 @@ class LongformerSelfAttention(nn.Module):
self.one_sided_attn_window_size = attention_window // 2

def forward(
self, hidden_states, attention_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
`attention_window` happens in LongformerModel.forward to avoid redoing the padding on each layer.
:class:`LongformerSelfAttention` expects `len(hidden_states)` to be multiple of `attention_window`. Padding to
`attention_window` happens in :meth:`LongformerModel.forward` to avoid redoing the padding on each layer.

The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to -ve: no attention
The `attention_mask` is changed in :meth:`BertModel.forward` from 0, 1, 2 to -ve: no attention

0: local attention
+ve: global attention
@@ -631,17 +637,17 @@ class LongformerSelfAttention(nn.Module):
# free memory
del global_key_attn_scores

local_attn_probs_fp32 = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
local_attn_probs = local_attn_probs_fp32.type_as(attn_scores)

# free memory
del local_attn_probs_fp32
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32) # use fp32 for numerical stability

# softmax sometimes inserts NaN if all positions are masked, replace them with 0
local_attn_probs = torch.masked_fill(local_attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
attn_probs = attn_probs.type_as(attn_scores)

# free memory
del attn_scores

# apply dropout
local_attn_probs = F.dropout(local_attn_probs, p=self.dropout, training=self.training)
attn_probs = F.dropout(attn_probs, p=self.dropout, training=self.training)

value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

@@ -650,7 +656,7 @@ class LongformerSelfAttention(nn.Module):
# compute sum of global and local attn
attn_output = self._compute_attn_output_with_global_indices(
value_vectors=value_vectors,
attn_probs=local_attn_probs,
attn_probs=attn_probs,
max_num_global_attn_indices=max_num_global_attn_indices,
is_index_global_attn_nonzero=is_index_global_attn_nonzero,
is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
@@ -658,7 +664,7 @@ class LongformerSelfAttention(nn.Module):
else:
# compute local attn only
attn_output = self._sliding_chunks_matmul_attn_probs_value(
local_attn_probs, value_vectors, self.one_sided_attn_window_size
attn_probs, value_vectors, self.one_sided_attn_window_size
)

assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
@@ -688,11 +694,14 @@ class LongformerSelfAttention(nn.Module):
# The attention weights for tokens with global attention are
# just filler values, they were never used to compute the output.
# Fill with 0 now, the correct values are in 'global_attn_probs'.
local_attn_probs[is_index_global_attn_nonzero] = 0
attn_probs[is_index_global_attn_nonzero] = 0

outputs = (attn_output.transpose(0, 1), local_attn_probs)
outputs = (attn_output.transpose(0, 1),)

return outputs + (global_attn_probs,) if is_global_attn else outputs
if output_attentions:
outputs += (attn_probs,)

return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs

@staticmethod
def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
@@ -711,6 +720,7 @@ class LongformerSelfAttention(nn.Module):
shift every row 1 step right, converting columns into diagonals.

Example::

chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
@@ -728,13 +738,13 @@ class LongformerSelfAttention(nn.Module):
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, -1
) # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap+window_overlap
) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
chunked_hidden_states = chunked_hidden_states[
:, :, :-window_overlap
] # total_num_heads x num_chunks x window_overlapL+window_overlapwindow_overlap
] # total_num_heads x num_chunks x window_overlap*window_overlap
chunked_hidden_states = chunked_hidden_states.view(
total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
) # total_num_heads x num_chunks, window_overlap x hidden_dim+window_overlap
)
chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
return chunked_hidden_states

@@ -788,18 +798,18 @@ class LongformerSelfAttention(nn.Module):
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

chunked_query = self._chunk(query, window_overlap)
chunked_key = self._chunk(key, window_overlap)
query = self._chunk(query, window_overlap)
key = self._chunk(key, window_overlap)

# matrix multiplication
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
# bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap
chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunked_query, chunked_key)) # multiply
diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply

# convert diagonals into columns
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
chunked_attention_scores, padding=(0, 0, 0, 1)
diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
)

# allocate space for the overall attention matrix where the chunks are combined. The last dimension
@@ -1095,7 +1105,13 @@ class LongformerAttention(nn.Module):
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self, hidden_states, attention_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
@@ -1103,6 +1119,7 @@ class LongformerAttention(nn.Module):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:]
@@ -1150,7 +1167,13 @@ class LongformerLayer(nn.Module):
self.seq_len_dim = 1

def forward(
self, hidden_states, attention_mask=None, is_index_masked=None, is_index_global_attn=None, is_global_attn=None
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
self_attn_outputs = self.attention(
hidden_states,
@@ -1158,6 +1181,7 @@ class LongformerLayer(nn.Module):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:]
@@ -1205,7 +1229,7 @@ class LongformerEncoder(nn.Module):

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, is_global_attn)
return module(*inputs, is_global_attn, output_attentions)

return custom_forward

@@ -1223,6 +1247,7 @@ class LongformerEncoder(nn.Module):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]

src/transformers/models/longformer/modeling_tf_longformer.py
@@ -796,7 +796,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)

# normalize query
query_vectors /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
query_vectors /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))

@@ -945,7 +945,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key) # multiply

# convert diagonals into columns
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)

# allocate space for the overall attention matrix where the chunks are combined. The last dimension
@@ -1093,7 +1093,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
)

# pad seq_len with w at the beginning of the sequence and another window overlap at the end
paddings = tf.constant([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]], dtype=tf.dtypes.int32)
padded_value = tf.pad(value, paddings, constant_values=-1)

# chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
@@ -1141,6 +1141,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
shift every row 1 step right, converting columns into diagonals.

Example::

chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599,
@@ -1153,7 +1154,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
"""
total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
paddings = tf.constant([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
chunked_hidden_states = tf.pad(
chunked_hidden_states, paddings
) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
@@ -1349,7 +1350,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
global_value_vectors = self.value_global(hidden_states)

# normalize
global_query_vectors_only_global /= tf.math.sqrt(tf.constant(self.head_dim, dtype=tf.dtypes.float32))
global_query_vectors_only_global /= tf.math.sqrt(tf.convert_to_tensor(self.head_dim, dtype=tf.dtypes.float32))
global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
@@ -1820,10 +1821,10 @@ class TFLongformerPreTrainedModel(TFPreTrainedModel):

@property
def dummy_inputs(self):
input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
input_ids = tf.convert_to_tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
# make sure global layers are initialized
attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
global_attention_mask = tf.constant([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])
attention_mask = tf.convert_to_tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
global_attention_mask = tf.convert_to_tensor([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
@@ -2371,9 +2372,9 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic

@property
def dummy_inputs(self):
input_ids = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)
input_ids = tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)
# make sure global layers are initialized
global_attention_mask = tf.constant([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
global_attention_mask = tf.convert_to_tensor([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
return {"input_ids": input_ids, "global_attention_mask": global_attention_mask}

@add_start_docstrings_to_model_forward(

src/transformers/utils/dummy_pt_objects.py
@@ -1179,6 +1179,45 @@ class LayoutLMModel:
        requires_pytorch(self)


LED_PRETRAINED_MODEL_ARCHIVE_LIST = None


class LEDForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class LEDForQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class LEDForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class LEDModel:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

src/transformers/utils/dummy_tf_objects.py
@@ -827,6 +827,33 @@ class TFGPT2PreTrainedModel:
        requires_tf(self)


class TFLEDForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)


class TFLEDModel:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)


class TFLEDPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)


TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

src/transformers/utils/dummy_tokenizers_objects.py
@@ -128,6 +128,15 @@ class LayoutLMTokenizerFast:
        requires_tokenizers(self)


class LEDTokenizerFast:
    def __init__(self, *args, **kwargs):
        requires_tokenizers(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tokenizers(self)


class LongformerTokenizerFast:
    def __init__(self, *args, **kwargs):
        requires_tokenizers(self)

tests/test_modeling_led.py (new file, 511 lines)
File diff suppressed because one or more lines are too long.

tests/test_modeling_longformer.py
@@ -488,13 +488,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()

output_hidden_states, _ = layer(
output_hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
)[0]

self.assertTrue(output_hidden_states.shape, (1, 4, 8))
self.assertTrue(
@@ -526,13 +526,13 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()

output_hidden_states, _, _ = layer(
output_hidden_states = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
)[0]

self.assertTrue(output_hidden_states.shape, (2, 4, 8))

@@ -583,6 +583,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
output_attentions=True,
)

self.assertEqual(local_attentions.shape, (2, 4, 2, 8))

tests/test_modeling_tf_led.py (new file, 348 lines)
@@ -0,0 +1,348 @@
# coding=utf-8
# Copyright Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import LEDConfig, is_tf_available
from transformers.testing_utils import require_tf, slow

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor


if is_tf_available():
    import tensorflow as tf

    from transformers import TFLEDForConditionalGeneration, TFLEDModel


@require_tf
class TFLEDModelTester:
    config_cls = LEDConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        attention_window=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.attention_window = attention_window

        # `ModelTesterMixin.test_attention_outputs` is expecting attention tensors to be of size
        # [num_attention_heads, encoder_seq_length, encoder_key_length], but TFLongformerSelfAttention
        # returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
        # because its local attention only attends to `self.attention_window` tokens plus one before and one after
        self.key_length = self.attention_window + 1

        # because of padding, `encoder_seq_length` is different from `seq_length`. Relevant for
        # the `test_attention_outputs` and `test_hidden_states_output` tests
        self.encoder_seq_length = (
            self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
        )
|
||||
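        # e.g. with seq_length=7 and attention_window=4 this pads to 7 + (4 - 7 % 4) % 4 = 8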

    def prepare_config_and_inputs_for_common(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
        eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
        input_ids = tf.concat([input_ids, eos_tensor], axis=1)

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.config_cls(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_ids=[2],
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.pad_token_id,
            attention_window=self.attention_window,
            **self.config_updates,
        )
        inputs_dict = prepare_led_inputs_dict(config, input_ids, decoder_input_ids)
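        # mark the last token of every sequence for global attention; all other tokens use local attention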
        global_attention_mask = tf.concat(
            [tf.zeros_like(input_ids)[:, :-1], tf.ones_like(input_ids)[:, -1:]],
            axis=-1,
        )
        inputs_dict["global_attention_mask"] = global_attention_mask
        return config, inputs_dict

    def check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = TFLEDModel(config=config).get_decoder()
        input_ids = inputs_dict["input_ids"]

        input_ids = input_ids[:1, :]
        attention_mask = inputs_dict["attention_mask"][:1, :]
        self.batch_size = 1

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()
        past_key_values = past_key_values[1]

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)

        # append to next_input_ids and next_attention_mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

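        # the cached forward pass over just the new tokens must match the last positions of the full, uncached pass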
        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)


def prepare_led_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
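    # by default, only pad tokens are masked out in the attention masks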
    if attention_mask is None:
        attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
    if decoder_attention_mask is None:
        decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
    }


@require_tf
class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (TFLEDForConditionalGeneration, TFLEDModel) if is_tf_available() else ()
    all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False

    def setUp(self):
        self.model_tester = TFLEDModelTester(self)
        self.config_tester = ConfigTester(self, config_class=LEDConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_model_common_attributes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
            x = model.get_output_layer_with_bias()
            assert x is None
            name = model.get_prefix_bias_name()
            assert name is None

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        inputs_dict["global_attention_mask"] = tf.zeros_like(inputs_dict["attention_mask"])
        num_global_attn_indices = 2
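        # give the first `num_global_attn_indices` tokens of every sequence global attention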
        inputs_dict["global_attention_mask"] = tf.where(
            tf.range(self.model_tester.seq_length)[None, :] < num_global_attn_indices,
            1,
            inputs_dict["global_attention_mask"],
        )

        config.return_dict = True
        seq_length = self.model_tester.seq_length
        encoder_seq_length = self.model_tester.encoder_seq_length

        def check_decoder_attentions_output(outputs):
            decoder_attentions = outputs.decoder_attentions
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_length, seq_length],
            )

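        # the encoder returns local and global attentions; the last dimension of the
        # global attentions equals the number of tokens with global attention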
        def check_encoder_attentions_output(outputs):
            attentions = [t.numpy() for t in outputs.encoder_attentions]
            global_attentions = [t.numpy() for t in outputs.encoder_global_attentions]
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
            self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
            )
            self.assertListEqual(
                list(global_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, num_global_attn_indices],
            )

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["use_cache"] = False
            config.output_hidden_states = False
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))
            out_len = len(outputs)
            self.assertEqual(config.output_hidden_states, False)
            check_encoder_attentions_output(outputs)

            if self.is_encoder_decoder:
                model = model_class(config)
                outputs = model(self._prepare_for_class(inputs_dict, model_class))
                self.assertEqual(config.output_hidden_states, False)
                check_decoder_attentions_output(outputs)

            # Check that output attentions can also be changed via the config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))
            self.assertEqual(config.output_hidden_states, False)
            check_encoder_attentions_output(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            config.output_hidden_states = True
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))

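            # turning on output_hidden_states adds two more outputs (encoder and decoder hidden states) for encoder-decoder models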
            self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
            self.assertEqual(model.config.output_hidden_states, True)
            check_encoder_attentions_output(outputs)

    @slow
    def test_saved_model_with_attentions_output(self):
        # longformer has special attentions which are not
        # compatible in graph mode
        pass

    @slow
    def test_saved_model_with_hidden_states_output(self):
        # TODO(JPLU, PVP) this test should pass!!! PVP:
        # IMO there is a problem with the signature check.
        # Test passes for TFLEDModel, but not for TFLEDForConditionalGeneration
        # IMO the reason is that the tensor variable name cannot be changed
        # from decoder_input_ids -> input_ids, which poses a BIG restriction
        pass


def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
    """If tensors are not close, or a and b aren't both tensors, raise a nice AssertionError."""
    if a is None and b is None:
        return True
    try:
        # raises if the two tensors are not element-wise close enough
        tf.debugging.assert_near(a, b, atol=atol)
        return True
    except Exception:
        msg = "{} != {}".format(a, b)
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg)


def _long_tensor(tok_lst):
    return tf.constant(tok_lst, dtype=tf.int32)


TOLERANCE = 1e-4


@slow
@require_tf
class TFLEDModelIntegrationTest(unittest.TestCase):
    def test_inference_no_head(self):
        model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384").led

        # change to intended input here
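        # 512 repetitions of an 8-token pattern -> 4096 encoder tokens; 128 repetitions -> 1024 decoder tokens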
        input_ids = _long_tensor([512 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
        decoder_input_ids = _long_tensor([128 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
        inputs_dict = prepare_led_inputs_dict(model.config, input_ids, decoder_input_ids)
        output = model(**inputs_dict)[0]
        expected_shape = (1, 1024, 768)
        self.assertEqual(output.shape, expected_shape)
        # change to expected output here
        expected_slice = tf.convert_to_tensor(
            [[2.3050, 2.8279, 0.6531], [-1.8457, -0.1455, -3.5661], [-1.0186, 0.4586, -2.2043]],
        )
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)

    def test_inference_with_head(self):
        model = TFLEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

        # change to intended input here
        input_ids = _long_tensor([512 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
        decoder_input_ids = _long_tensor([128 * [0, 31414, 232, 328, 740, 1140, 12695, 69]])
        inputs_dict = prepare_led_inputs_dict(model.config, input_ids, decoder_input_ids)
        output = model(**inputs_dict)[0]
        expected_shape = (1, 1024, model.config.vocab_size)
        self.assertEqual(output.shape, expected_shape)
        # change to expected output here
        expected_slice = tf.convert_to_tensor(
            [[33.6507, 6.4572, 16.8089], [5.8739, -2.4238, 11.2902], [-3.2139, -4.3149, 4.2783]],
        )
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=TOLERANCE)
@ -339,6 +339,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):

    @slow
    def test_saved_model_with_attentions_output(self):
        # longformer has special attentions which are not
        # compatible in graph mode
        pass

@ -30,6 +30,8 @@ PATH_TO_DOC = "docs/source"
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
    # models to ignore for not tested
    "LEDEncoder",  # Building part of bigger (tested) model.
    "LEDDecoder",  # Building part of bigger (tested) model.
    "BartDecoder",  # Building part of bigger (tested) model.
    "BartEncoder",  # Building part of bigger (tested) model.
    "BertLMHeadModel",  # Needs to be setup as decoder.
@ -64,6 +66,8 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
    # models to ignore for model xxx mapping
    "LEDEncoder",
    "LEDDecoder",
    "BartDecoder",
    "BartEncoder",
    "DPRContextEncoder",