Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
PegasusForConditionalGeneration (torch version) (#6340)
Co-authored-by: Jingqing Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in: parent f6cb0f806e, commit 66fa8ceaea
@ -124,7 +124,9 @@ conversion utilities for the following models:
22. `DPR <https://github.com/facebookresearch/DPR>`_ (from Facebook) released with the paper `Dense Passage Retrieval
    for Open-Domain Question Answering <https://arxiv.org/abs/2004.04906>`_ by Vladimir Karpukhin, Barlas Oğuz, Sewon
    Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
23. `Pegasus <https://github.com/google-research/pegasus>`_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization
    <https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
24. `Other community models <https://huggingface.co/models>`_, contributed by the `community
    <https://huggingface.co/users>`_.

.. toctree::
@ -205,6 +207,7 @@ conversion utilities for the following models:
   model_doc/retribert
   model_doc/mobilebert
   model_doc/dpr
   model_doc/pegasus
   internal/modeling_utils
   internal/tokenization_utils
   internal/pipelines_utils
@ -17,13 +17,23 @@ According to the abstract,
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_


Implementation Notes
~~~~~~~~~~~~~~~~~~~~

- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use ``BartTokenizer.encode`` to get the proper splitting.
- The forward pass of ``BartModel`` will create decoder inputs (using the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``) if they are not passed. This is different than some other modeling APIs.
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to ``fairseq.encode`` starts with a space.
- ``BartForConditionalGeneration.generate`` should be used for conditional generation tasks like summarization; see the example in its docstring and the sketch below.
- Models that load the ``"facebook/bart-large-cnn"`` weights will not have a ``mask_token_id`` and cannot perform mask-filling tasks.
- For training/forward passes that don't involve beam search, pass ``use_cache=False``.
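
A minimal summarization sketch with the ``facebook/bart-large-cnn`` weights (the generation settings below are illustrative, not tuned defaults):

.. code-block:: python

    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

    inputs = tokenizer(["PG&E scheduled the blackouts in response to forecasts for high winds."], return_tensors="pt")
    # beam-search generation for summarization
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=40, early_stopping=True)
    print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))

    # plain forward pass (e.g. for training): disable the cache
    outputs = model(**inputs, use_cache=False)
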
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BartForConditionalGeneration
    :members: generate, forward


BartConfig
~~~~~~~~~~
@ -45,11 +55,7 @@ MBartTokenizer
.. autoclass:: transformers.MBartTokenizer
    :members: build_inputs_with_special_tokens, prepare_translation_batch


BartModel
~~~~~~~~~
docs/source/model_doc/pegasus.rst (new file, 111 lines)
@ -0,0 +1,111 @@
Pegasus
----------------------------------------------------
**DISCLAIMER:** If you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=sshleifer&labels=&template=bug-report.md&title>`__ and assign
@sshleifer.


Overview
~~~~~~~~

The Pegasus model was `proposed <https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu on Dec 18, 2019.
According to the abstract,

- Pegasus' pretraining task is intentionally similar to summarization: important sentences are removed/masked from an input document and are generated together as one output sequence from the remaining sentences, similar to an extractive summary.
- Pegasus achieves SOTA summarization performance on all 12 downstream tasks, as measured by ROUGE and human eval.

The authors' code can be found `here <https://github.com/google-research/pegasus>`_.


Checkpoints
~~~~~~~~~~~
All `checkpoints <https://huggingface.co/models?search=pegasus>`_ are fine-tuned for summarization, except ``pegasus-large``, which is the checkpoint the others are fine-tuned from.

- Each checkpoint is 2.2 GB on disk and has 568M parameters.
- FP16 is not supported (help/ideas on this appreciated!).
- Summarizing XSUM in fp32 takes about 400 ms/sample with default parameters on a V100 GPU.
- For XSUM, the paper reports ROUGE-1/ROUGE-2/ROUGE-L of 47.21/24.56/39.25. As of Aug 9, this port scores 46.91/24.34/39.1.
  The gap is likely due to different alpha/``length_penalty`` implementations in beam search; a sketch of how to experiment with those settings follows.
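
A minimal sketch of overriding those beam-search settings at generation time (the values below are illustrative, not the tuned per-dataset defaults):

.. code-block:: python

    from transformers import PegasusForConditionalGeneration, PegasusTokenizer

    tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
    model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")

    batch = tokenizer.prepare_seq2seq_batch(["Some long news article ..."], truncation=True, padding="longest")
    generated = model.generate(**batch, num_beams=8, length_penalty=0.6, max_length=64, early_stopping=True)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))
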

Implementation Notes
~~~~~~~~~~~~~~~~~~~~

- All models are transformer encoder-decoders with 16 layers in each component.
- The implementation is completely inherited from ``BartForConditionalGeneration``.
- Some key configuration differences:

  - static, sinusoidal position embeddings
  - no ``layernorm_embedding`` (``PegasusConfig.normalize_embedding=False``)
  - the model starts generating with ``pad_token_id`` (which has a 0 token embedding) as the prefix
  - ``num_beams=8``

- All pretrained pegasus checkpoints are the same apart from three attributes: ``tokenizer.model_max_length`` (maximum input size), ``max_length`` (maximum number of tokens to generate) and ``length_penalty``.
- Code to convert checkpoints trained in the authors' `repo <https://github.com/google-research/pegasus>`_ can be found in ``convert_pegasus_tf_to_pytorch.py``.


Usage Example
~~~~~~~~~~~~~

.. code-block:: python

    import torch
    from transformers import PegasusForConditionalGeneration, PegasusTokenizer

    src_text = [
        """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
    ]

    model_name = 'google/pegasus-xsum'
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
    batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest').to(torch_device)
    translated = model.generate(**batch)
    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    assert tgt_text[0] == "California's largest electricity provider has turned off power to tens of thousands of customers."

PegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This class inherits all functionality from ``BartForConditionalGeneration``; see that page for method signatures.
Available models are listed at the `model list <https://huggingface.co/models?search=pegasus>`__.


PegasusConfig
~~~~~~~~~~~~~
This config fully inherits from ``BartConfig``, but pegasus uses different default values.
Up-to-date parameter values can be seen in `S3 <https://s3.amazonaws.com/models.huggingface.co/bert/google/pegasus-xsum/config.json>`_.
As of Aug 10, 2020, they are:

.. code-block:: python

    dict(
        vocab_size=96103,
        max_position_embeddings=512,
        d_model=1024,
        encoder_ffn_dim=4096,
        decoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_attention_heads=16,
        encoder_layers=16,
        decoder_layers=16,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        pad_token_id=0,
        eos_token_id=1,
        is_encoder_decoder=True,
        normalize_before=True,
        scale_embedding=True,
        normalize_embedding=False,
        add_final_layer_norm=True,
        static_position_embeddings=True,
        num_beams=8,
        activation_function="relu",
    )
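
As a quick check (a sketch; the expected values come from the table above and the linked config), the downloaded config should reflect these defaults:

.. code-block:: python

    from transformers import PegasusConfig

    config = PegasusConfig.from_pretrained("google/pegasus-xsum")
    print(config.encoder_layers, config.num_beams, config.activation_function)  # expected: 16 8 relu
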

PegasusTokenizer
~~~~~~~~~~~~~~~~
Warning: ``add_tokens`` does not work at the moment.

.. autoclass:: transformers.PegasusTokenizer
    :members: __call__, prepare_seq2seq_batch

@ -353,6 +353,8 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Pegasus | ``google/pegasus-{dataset}`` | | 16-layer, 1024-hidden, 16-heads, ~568M parameter, 2.2 GB for summary. `model list <https://huggingface.co/models?search=pegasus>`__ |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Longformer | ``allenai/longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
| | | | Starting from RoBERTa-base checkpoint, trained on documents of max length 4,096 |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
@ -413,6 +413,18 @@ def get_layers_to_copy(n_to_get, tot):
            12: all_layers,
        }
        return layers_to_copy[n_to_get]
    elif tot == 16:
        layers_to_copy = {  # maps num layers in student -> which teacher layers to copy
            1: [0],
            2: [0, 8],
            3: [0, 8, 15],
            4: [0, 5, 10, 15],
            6: [0, 3, 6, 9, 12, 15],
            8: [0, 2, 4, 6, 8, 10, 12, 15],
            9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
            16: all_layers,
        }
        return layers_to_copy[n_to_get]
    else:
        return all_layers[:n_to_get]  # TODO: better version on theseus-bart branch
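# Example (sketch): distilling a 16-layer Pegasus teacher into a 4-layer student copies
# teacher layers [0, 5, 10, 15], i.e. get_layers_to_copy(4, 16) == [0, 5, 10, 15];
# get_layers_to_copy(1, 16) == [0].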
examples/seq2seq/finetune_pegasus_xsum.sh (new executable file, 14 lines)
@ -0,0 +1,14 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"

# From appendix C of paper https://arxiv.org/abs/1912.08777
# Set --gradient_accumulation_steps so that effective batch size is 256 (2*128, 4*64, 8*32, 16*16)
python finetune.py \
    --learning_rate=1e-4 \
    --do_train \
    --do_predict \
    --n_val 1000 \
    --val_check_interval 0.25 \
    --max_source_length 512 --max_target_length 56 \
    --freeze_embeds --max_target_length 56 --label_smoothing 0.1 \
    "$@"
@ -37,6 +37,7 @@ from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -150,6 +151,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@ -287,6 +289,7 @@ if is_torch_available():
        XLMForMultipleChoice,
        XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    from .modeling_pegasus import PegasusForConditionalGeneration
    from .modeling_bart import (
        PretrainedBartModel,
        BartForSequenceClassification,
@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
from .configuration_marian import MarianConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -81,6 +82,7 @@ CONFIG_MAPPING = OrderedDict(
        ("albert", AlbertConfig,),
        ("camembert", CamembertConfig,),
        ("xlm-roberta", XLMRobertaConfig,),
        ("pegasus", PegasusConfig),
        ("marian", MarianConfig,),
        ("mbart", MBartConfig,),
        ("bart", BartConfig,),
@ -18,6 +18,7 @@
import logging

from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_callable


logger = logging.getLogger(__name__)
@ -31,8 +32,73 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
    "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}
BART_CONFIG_ARGS_DOC = r"""
    Args:
        vocab_size (:obj:`int`, optional, defaults to 50265):
            Defines the number of different tokens that can be represented by the :obj:`inputs_ids` passed to the forward method.
        d_model (:obj:`int`, optional, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (:obj:`int`, optional, defaults to 12):
            Number of encoder layers, 16 for pegasus, 6 for bart-base and marian.
        decoder_layers (:obj:`int`, optional, defaults to 12):
            Number of decoder layers, 16 for pegasus, 6 for bart-base and marian.
        encoder_attention_heads (:obj:`int`, optional, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (:obj:`int`, optional, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
        encoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the encoder.
        activation_function (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
            The non-linear activation function (function or string) in the encoder and pooler.
            If string, "gelu", "relu", "swish" and "gelu_new" are supported.
        dropout (:obj:`float`, optional, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (:obj:`float`, optional, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (:obj:`float`, optional, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (:obj:`float`, optional, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (:obj:`int`, optional, defaults to 1024):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
        init_std (:obj:`float`, optional, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        add_bias_logits (:obj:`bool`, optional, defaults to False):
            True for marian only.
        normalize_before (:obj:`bool`, optional, defaults to False):
            Call layernorm before attention ops. True for pegasus and mbart, False for bart. FIXME: marian?
        normalize_embedding (:obj:`bool`, optional, defaults to True):
            Call layernorm after embeddings. Only True for Bart.
        static_position_embeddings (:obj:`bool`, optional, defaults to False):
            Don't learn positional embeddings, use sinusoidal ones. True for marian and pegasus.
        add_final_layer_norm (:obj:`bool`, optional, defaults to False):
            Add a final layernorm to the encoder and decoder outputs (True for pegasus).
        scale_embedding (:obj:`bool`, optional, defaults to False):
            Scale embeddings by dividing by sqrt(d_model).
        eos_token_id (:obj:`int`, optional, defaults to 2):
            End of stream token id.
        pad_token_id (:obj:`int`, optional, defaults to 1):
            Padding token id.
        bos_token_id (:obj:`int`, optional, defaults to 0):
            Beginning of stream token id.
        encoder_layerdrop (:obj:`float`, optional, defaults to 0.0):
            LayerDrop probability for the encoder; see the `LayerDrop paper <https://arxiv.org/abs/1909.11556>`__ for details.
        decoder_layerdrop (:obj:`float`, optional, defaults to 0.0):
            LayerDrop probability for the decoder; see the `LayerDrop paper <https://arxiv.org/abs/1909.11556>`__ for details.
        extra_pos_embeddings (:obj:`int`, optional, defaults to 2):
            How many extra learned positional embeddings to use. Should be pad_token_id+1 for bart.
        num_labels (:obj:`int`, optional, defaults to 2):
            Number of labels, used for ``BartForSequenceClassification``.
        is_encoder_decoder (:obj:`bool`, optional, defaults to True):
            Whether this is an encoder/decoder model.

"""


@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
class BartConfig(PretrainedConfig):
    r"""
    Configuration class for Bart. Parameters are renamed from the fairseq implementation
@ -42,7 +108,7 @@ class BartConfig(PretrainedConfig):
    def __init__(
        self,
        activation_dropout=0.0,
        extra_pos_embeddings=2,  # FIXME(@sshleifer): delete?
        activation_function="gelu",
        vocab_size=50265,
        d_model=1024,
@ -81,6 +147,7 @@ class BartConfig(PretrainedConfig):
        >>> config = BartConfig.from_pretrained('facebook/bart-large')
        >>> model = BartModel(config)

        """
        if "hidden_size" in common_kwargs:
            raise ValueError("hidden size is called d_model")
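        # Illustration (sketch): BartConfig and its subclasses take `d_model` rather than `hidden_size`,
        # so BartConfig(hidden_size=1024) raises the error above, while BartConfig(d_model=1024) is accepted.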
@ -146,3 +213,4 @@ class BartConfig(PretrainedConfig):

class MBartConfig(BartConfig):
    model_type = "mbart"
    """See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""
src/transformers/configuration_pegasus.py (new file, 62 lines)
@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PEGASUS model configuration """

import logging

from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
from .file_utils import add_start_docstrings_to_callable


logger = logging.getLogger(__name__)

DEFAULTS = dict(
    vocab_size=96103,
    max_position_embeddings=512,
    d_model=1024,
    encoder_ffn_dim=4096,
    decoder_ffn_dim=4096,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
    encoder_layers=16,
    decoder_layers=16,
    dropout=0.1,
    attention_dropout=0.1,
    activation_dropout=0.1,
    pad_token_id=0,
    eos_token_id=1,
    is_encoder_decoder=True,
    normalize_before=True,
    scale_embedding=True,
    normalize_embedding=False,
    add_final_layer_norm=True,
    static_position_embeddings=True,
    num_beams=8,
    activation_function="relu",
)


@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
class PegasusConfig(BartConfig):
    r"""
    :class:`~transformers.PegasusConfig` is the configuration class to store the configuration of a
    `PegasusModel`.
    """
    model_type = "pegasus"
    # The implementation of the config object is in BartConfig

    @property
    def default_config_parameters(self):
        return DEFAULTS
src/transformers/convert_pegasus_tf_to_pytorch.py (new file, 167 lines)
@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path
from typing import Dict

import tensorflow as tf
import torch
from tqdm import tqdm

from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.configuration_pegasus import DEFAULTS


PATTERNS = [
    # replace left string with right string to get the relevant state_dict key (identical state dict to bart)
    ["memory_attention", "encoder_attn"],
    ["attention", "attn"],
    ["/", "."],
    [".LayerNorm.gamma", "_layer_norm.weight"],
    [".LayerNorm.beta", "_layer_norm.bias"],
    ["r.layer_", "r.layers."],
    ["output_proj", "out_proj"],
    ["ffn.dense_1.", "fc2."],
    ["ffn.dense.", "fc1."],
    ["ffn_layer_norm", "final_layer_norm"],
    ["kernel", "weight"],
    ["encoder_layer_norm.", "encoder.layer_norm."],
    ["decoder_layer_norm.", "decoder.layer_norm."],
    ["embeddings.weights", "shared.weight"],
]


def rename_state_dict_key(k):
    for pegasus_name, bart_name in PATTERNS:
        k = k.replace(pegasus_name, bart_name)
    return k

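# Illustration (the TF variable name below is hypothetical; it only demonstrates the pattern rewriting):
#   rename_state_dict_key("encoder/layer_0/self_attention/output_proj/kernel")
#   -> "encoder.layers.0.self_attn.out_proj.weight"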
# See appendix C of paper for all hyperparams
max_gen_length = {
    # See appendix C of paper
    "xsum": 64,
    "cnn_dailymail": 128,
    "newsroom": 128,
    "wikihow": 256,
    "multi_news": 256,
    "reddit_tifu": 128,
    "big_patent": 256,
    "arxiv": 256,
    "pubmed": 256,
    "gigaword": 32,
    "aeslc": 32,
    "billsum": 256,
    "large": 256,  # @sshleifer chose arbitrarily
}
max_model_length = {
    "xsum": 512,
    "cnn_dailymail": 1024,
    "newsroom": 512,
    "wikihow": 512,
    "multi_news": 1024,
    "reddit_tifu": 512,
    "big_patent": 1024,
    "arxiv": 1024,
    "pubmed": 1024,
    "gigaword": 128,
    "aeslc": 512,
    "billsum": 1024,
    "large": 1024,
}

expected_alpha = {
    "multinews": 0.9,
    "wikihow": 0.6,
    "reddit_tifu": 0.6,
    "big_patent": 0.7,
    "gigaword": 0.6,
    "aeslc": 0.6,
    "billsum": 0.6,
}  # otherwise 0.8
# TODO(SS): one constant


def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
    cfg_kwargs = DEFAULTS.copy()
    cfg_kwargs.update(cfg_updates)

    cfg = PegasusConfig(**cfg_kwargs)  # use the merged defaults, not just the updates
    bart = PegasusForConditionalGeneration(cfg)
    sd = bart.model.state_dict()
    mapping = {}
    for k, v in tf_weights.items():
        new_k = rename_state_dict_key(k)
        if new_k not in sd:
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")

        if "dense" in k or "proj" in new_k:
            v = v.T  # TF stores linear kernels transposed relative to torch.nn.Linear
        assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}"
        mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype)
    # make sure embedding.padding_idx is respected
    mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1])
    mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"]
    mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
    empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
    mapping.update(**empty_biases)
    missing, extra = bart.model.load_state_dict(mapping, strict=False)
    unexpected_missing = [
        k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
    ]
    assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
    assert extra == [], f"no matches found for the following tf keys {extra}"
    return bart


def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
    init_vars = tf.train.list_variables(path)
    tf_weights = {}
    ignore_name = ["Adafactor", "global_step"]
    for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
        skip_key = any([pat in name for pat in ignore_name])
        if skip_key:
            continue
        array = tf.train.load_variable(path, name)
        tf_weights[name] = array
    return tf_weights


def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
    # save tokenizer first
    dataset = Path(ckpt_path).parent.name
    desired_max_model_length = max_model_length[dataset]
    tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
    assert tok.model_max_length == desired_max_model_length
    tok.save_pretrained(save_dir)

    # convert model
    tf_weights = get_tf_weights_as_numpy(ckpt_path)
    cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
    torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
    torch_model.save_pretrained(save_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
    parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.")
    args = parser.parse_args()
    if args.save_dir is None:
        args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}"
    convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
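# Example invocation (sketch; the checkpoint path is hypothetical):
#   python convert_pegasus_tf_to_pytorch.py ./ckpt/xsum/model.ckpt-30000 ./pegasus-xsum-pt
# The parent directory name of the checkpoint ("xsum" here) selects the per-dataset
# max_model_length, max_length and length_penalty values defined above.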
@ -34,6 +34,7 @@ from .configuration_auto import (
    LongformerConfig,
    MobileBertConfig,
    OpenAIGPTConfig,
    PegasusConfig,
    ReformerConfig,
    RetriBertConfig,
    RobertaConfig,
@ -125,6 +126,7 @@ from .modeling_mobilebert import (
    MobileBertModel,
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_reformer import (
    ReformerForMaskedLM,
    ReformerForQuestionAnswering,
@ -283,6 +285,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        (T5Config, T5ForConditionalGeneration),
        (PegasusConfig, PegasusForConditionalGeneration),
        (MarianConfig, MarianMTModel),
        (BartConfig, BartForConditionalGeneration),
        (EncoderDecoderConfig, EncoderDecoderModel),
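# Sketch: with PegasusConfig registered in this mapping, the Auto classes resolve pegasus checkpoints, e.g.
#   from transformers import AutoModelForSeq2SeqLM
#   model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum")
#   assert type(model).__name__ == "PegasusForConditionalGeneration"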
@ -19,9 +19,7 @@ from .configuration_marian import MarianConfig
from .modeling_bart import BartForConditionalGeneration


MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
]
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP


class MarianMTModel(BartForConditionalGeneration):
src/transformers/modeling_pegasus.py (new file, 46 lines)
@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Pegasus model, ported from https://github.com/google-research/pegasus"""


from .configuration_pegasus import PegasusConfig
from .file_utils import add_start_docstrings
from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration


@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
class PegasusForConditionalGeneration(BartForConditionalGeneration):
    config_class = PegasusConfig
    r"""
    Pytorch version of google's pegasus model for summarization.
    Model API is identical to BartForConditionalGeneration.
    Available models are listed at `Model List <https://huggingface.co/models?search=pegasus>`__

    Examples::

        >>> from transformers import PegasusTokenizer, PegasusForConditionalGeneration
        >>> from typing import List
        >>> PGE_ARTICLE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
        >>> mname = "google/pegasus-xsum"

        >>> model = PegasusForConditionalGeneration.from_pretrained(mname)
        >>> tok = PegasusTokenizer.from_pretrained(mname)
        >>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE])  # don't need tgt_text for inference
        >>> gen = model.generate(**batch)  # for forward pass: model(**batch)
        >>> summary: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
        >>> assert summary[0] == "California's largest electricity provider has turned off power to tens of thousands of customers."

    """
    # All the code is in src/transformers/modeling_bart.py
@ -30,8 +30,11 @@ from .configuration_auto import (
    FlaubertConfig,
    GPT2Config,
    LongformerConfig,
    MarianConfig,
    MBartConfig,
    MobileBertConfig,
    OpenAIGPTConfig,
    PegasusConfig,
    ReformerConfig,
    RetriBertConfig,
    RobertaConfig,
@ -41,8 +44,6 @@ from .configuration_auto import (
    XLMRobertaConfig,
    XLNetConfig,
)
from .configuration_marian import MarianConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bart import BartTokenizer, MBartTokenizer
@ -58,6 +59,7 @@ from .tokenization_longformer import LongformerTokenizer
from .tokenization_marian import MarianTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@ -79,6 +81,7 @@ TOKENIZER_MAPPING = OrderedDict(
        (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
        (AlbertConfig, (AlbertTokenizer, None)),
        (CamembertConfig, (CamembertTokenizer, None)),
        (PegasusConfig, (PegasusTokenizer, None)),
        (MBartConfig, (MBartTokenizer, None)),
        (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
        (MarianConfig, (MarianTokenizer, None)),
src/transformers/tokenization_pegasus.py (new file, 193 lines)
@ -0,0 +1,193 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional

from transformers.tokenization_reformer import ReformerTokenizer

from .tokenization_utils_base import BatchEncoding


class PegasusTokenizer(ReformerTokenizer):
    offset = 103  # entries 2-104 are only used for pretraining
    vocab_files_names = {"vocab_file": "spiece.model"}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Don't use reserved words added_token_encoder, added_tokens_decoder because of
        # AssertionError: Non-consecutive added token '1' found. in from_pretrained
        assert len(self.added_tokens_decoder) == 0
        self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
        # entries 2-104 are only used for pretraining and called unk_2, ...unk_104
        self.encoder.update({i: f"unk_{i}" for i in range(2, self.offset + 2)})
        self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}

    def _convert_token_to_id(self, token: str) -> int:
        """ Converts a token (str) to an id using the vocab. """
        if token in self.decoder:
            return self.decoder[token]
        elif token in self.added_tokens_decoder:
            return self.added_tokens_decoder[token]
        sp_id = self.sp_model.piece_to_id(token)
        return sp_id + self.offset

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) to a token (str) using the vocab."""
        if index in self.encoder:
            return self.encoder[index]
        elif index in self.added_tokens_encoder:
            return self.added_tokens_encoder[index]
        else:
            # assert index > self.offset, f"cannot decode ids between 2 and {self.offset}. Got {index}"
            token = self.sp_model.IdToPiece(index - self.offset)
            return token

    @property
    def vocab_size(self) -> int:
        return len(self.sp_model) + self.offset
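    # Offset bookkeeping illustrated (a sketch; assumes a pretrained sentencepiece vocab is available):
    #   tok = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
    #   tok.convert_ids_to_tokens([0, 1, 2])                    # ["<pad>", "</s>", "unk_2"]
    #   sp_id = tok.sp_model.piece_to_id("▁the")                # raw SentencePiece id
    #   tok._convert_token_to_id("▁the") == sp_id + tok.offset  # shifted past the 103 reserved entries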

    def get_vocab(self) -> Dict[str, int]:
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def num_special_tokens_to_add(self, pair=False):
        """Just EOS"""
        return 1

    def _special_token_mask(self, seq):
        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
        assert all_special_ids == set([0, 1])
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
        if already_has_special_tokens:
            return self._special_token_mask(token_ids_0)
        elif token_ids_1 is None:
            return self._special_token_mask(token_ids_0) + [1]
        else:
            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """
        Build model inputs from a sequence by adding eos to the end. No bos token is added to the front.

        - single sequence: ``X </s>``
        - pair of sequences: ``A B </s>`` (not intended use)

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        return_tensors: str = "pt",
        truncation=True,
        padding="longest",
    ) -> BatchEncoding:
        """
        Prepare model inputs for summarization or translation.

        Arguments:
            src_texts (:obj:`list`):
                list of documents to summarize or source language texts
            tgt_texts (:obj:`list`, `optional`):
                list of tgt language texts or summaries.
            max_length (:obj:`int`, `optional`):
                Controls the maximum length for encoder inputs (documents to summarize or source language texts).
                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
                length is required by one of the truncation/padding parameters. If the model has no specific maximum
                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
            max_target_length (:obj:`int`, `optional`):
                Controls the maximum length of decoder inputs (target language texts or summaries).
                If left unset or set to :obj:`None`, this will use the max_length value.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
                  different lengths).
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
                  if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
                  the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
                  to the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batch with
                  sequence lengths greater than the model maximum admissible input size).

        Return:
            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:

            - **input_ids** -- List of token ids to be fed to the encoder.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
              This does not include the causal mask, which is built by the model.

            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``
            will only be returned if tgt_texts is passed. Otherwise, input_ids and attention_mask will be the only keys.
        """
        if "" in src_texts:
            raise ValueError(f"found empty string in src_texts: {src_texts}")
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
        )
        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
        if tgt_texts is None:
            return model_inputs
        if max_target_length is not None:
            tokenizer_kwargs["max_length"] = max_target_length
        decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v
        return model_inputs
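    # Usage sketch for prepare_seq2seq_batch (texts here are illustrative):
    #   tok = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
    #   batch = tok.prepare_seq2seq_batch(["a long document"], tgt_texts=["a short summary"], max_target_length=5)
    #   sorted(batch.keys()) == ["attention_mask", "decoder_attention_mask", "decoder_input_ids", "input_ids"]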
@ -45,6 +45,7 @@ if is_torch_available():
        _prepare_bart_decoder_inputs,
        SinusoidalPositionalEmbedding,
    )
    PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""


@require_torch
@ -478,7 +479,6 @@ class BartModelIntegrationTests(unittest.TestCase):
        self.assertFalse(model.config.is_valid_mbart())
        tok = BartTokenizer.from_pretrained("facebook/bart-large")

        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
        dct = tok.batch_encode_plus(
            [PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
@ -23,14 +23,13 @@ RO_CODE = 250020


@require_torch
class AbstractMBartIntegrationTest(unittest.TestCase):
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
    maxDiff = 1000  # longer string compare tracebacks
    checkpoint_name = None

    @classmethod
    def setUpClass(cls):
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
        cls.pad_token_id = 1
        return cls

    @cached_property
@ -43,7 +42,7 @@ class AbstractMBartIntegrationTest(unittest.TestCase):


@require_torch
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
    checkpoint_name = "facebook/mbart-large-en-ro"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
@ -73,7 +72,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
                ]
            ),
        }
        net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
        net_input["attention_mask"] = net_input["input_ids"].ne(1)
        with torch.no_grad():
            logits, *other_stuff = model(**net_input)

@ -125,7 +124,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):


@require_torch
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
    checkpoint_name = "facebook/mbart-large-cc25"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
tests/test_modeling_pegasus.py (new file, 79 lines)
@ -0,0 +1,79 @@
import unittest

from transformers import AutoConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_modeling_bart import PGE_ARTICLE
from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest


if is_torch_available():
    from transformers import AutoModelForSeq2SeqLM

XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """


@require_torch
class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
    checkpoint_name = "google/pegasus-xsum"
    src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
    tgt_text = [
        "California's largest electricity provider has turned off power to tens of thousands of customers.",
        "N-Dubz have revealed they weren't expecting to get four nominations at this year's Mobo Awards.",
    ]

    @cached_property
    def model(self):
        return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)

    @slow
    def test_pegasus_xsum_summary(self):
        assert self.tokenizer.model_max_length == 512
        inputs = self.tokenizer(self.src_text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(
            torch_device
        )
        assert inputs.input_ids.shape == (2, 421)
        translated_tokens = self.model.generate(**inputs)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text, decoded)

        if "cuda" not in torch_device:
            return
        # Demonstrate fp16 issue, Contributions welcome!
        self.model.half()
        translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
        decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
        bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
        self.assertListEqual(decoded, bad_fp16_result)


class PegasusConfigTests(unittest.TestCase):
    def test_all_config_max_lengths(self):
        expected_max_length = {
            # See appendix C of paper
            "xsum": 64,
            "cnn_dailymail": 128,
            "newsroom": 128,
            "wikihow": 256,
            "multi_news": 256,
            "reddit_tifu": 128,
            "big_patent": 256,
            "arxiv": 256,
            "pubmed": 256,
            "gigaword": 32,
            "aeslc": 32,
            "billsum": 256,
        }
        failures = []
        pegasus_prefix = "google/pegasus"
        for dataset, max_len in expected_max_length.items():
            mname = f"{pegasus_prefix}-{dataset}"
            cfg = AutoConfig.from_pretrained(mname)
            if cfg.max_length != max_len:
                failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
        if failures == []:
            return
        # error
        all_fails = "\n".join(failures)
        raise AssertionError(f"The following configs have unexpected settings: {all_fails}")
tests/test_tokenization_pegasus.py (new file, 69 lines)
@ -0,0 +1,69 @@
import unittest
from pathlib import Path

from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.tokenization_pegasus import PegasusTokenizer

from .test_tokenization_common import TokenizerTesterMixin


class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = PegasusTokenizer

    def setUp(self):
        super().setUp()

        save_dir = Path(self.tmpdirname)
        spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
        if not (save_dir / spm_file).exists():
            tokenizer = self.pegasus_large_tokenizer
            tokenizer.save_pretrained(self.tmpdirname)

    @cached_property
    def pegasus_large_tokenizer(self):
        return PegasusTokenizer.from_pretrained("google/pegasus-large")

    @unittest.skip("add_tokens does not work yet")
    def test_swap_special_token(self):
        pass

    def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
        if not kwargs:
            return self.pegasus_large_tokenizer
        else:
            return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self, tokenizer):
        return ("This is a test", "This is a test")

    def test_pegasus_large_tokenizer_settings(self):
        tokenizer = self.pegasus_large_tokenizer
        # The tracebacks for the following asserts are **better** without messages or self.assertEqual
        assert tokenizer.vocab_size == 96103
        assert tokenizer.pad_token_id == 0
        assert tokenizer.eos_token_id == 1
        assert tokenizer.offset == 103
        assert tokenizer.unk_token_id == tokenizer.offset + 2 == 105
        assert tokenizer.unk_token == "<unk>"
        assert tokenizer.mask_token is None
        assert tokenizer.mask_token_id is None
        assert tokenizer.model_max_length == 1024
        raw_input_str = "To ensure a smooth flow of bank resolutions."
        desired_result = [413, 615, 114, 2291, 1971, 113, 1679, 10710, 107, 1]
        ids = tokenizer([raw_input_str], return_tensors=None).input_ids[0]
        self.assertListEqual(desired_result, ids)
        assert tokenizer.convert_ids_to_tokens([0, 1, 2]) == ["<pad>", "</s>", "unk_2"]

    @require_torch
    def test_pegasus_large_seq2seq_truncation(self):
        src_texts = ["This is going to be way too long" * 10000, "short example"]
        tgt_texts = ["not super long but more than 5 tokens", "tiny"]
        batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
        assert batch.input_ids.shape == (2, 1024)
        assert batch.attention_mask.shape == (2, 1024)
        assert "decoder_input_ids" in batch  # because tgt_texts was specified
        assert batch.decoder_input_ids.shape == (2, 5)
        assert batch.decoder_attention_mask.shape == (2, 5)
        assert len(batch) == 4  # no extra keys