PegasusForConditionalGeneration (torch version) (#6340)

Co-authored-by: Jingqing  Zhang <jingqing.zhang15@imperial.ac.uk>
This commit is contained in:
Sam Shleifer 2020-08-11 14:31:23 -04:00 committed by GitHub
parent f6cb0f806e
commit 66fa8ceaea
20 changed files with 860 additions and 20 deletions

View File

@ -124,7 +124,9 @@ conversion utilities for the following models:
22. `DPR <https://github.com/facebookresearch/DPR>`_ (from Facebook) released with the paper `Dense Passage Retrieval
for Open-Domain Question Answering <https://arxiv.org/abs/2004.04906>`_ by Vladimir Karpukhin, Barlas Oğuz, Sewon
Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
23. `Other community models <https://huggingface.co/models>`_, contributed by the `community
23. `Pegasus <https://github.com/google-research/pegasus>`_ (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization
<https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
24. `Other community models <https://huggingface.co/models>`_, contributed by the `community
<https://huggingface.co/users>`_.
.. toctree::
@ -205,6 +207,7 @@ conversion utilities for the following models:
model_doc/retribert
model_doc/mobilebert
model_doc/dpr
model_doc/pegasus
internal/modeling_utils
internal/tokenization_utils
internal/pipelines_utils

View File

@ -17,13 +17,23 @@ According to the abstract,
The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_
Implementation Notes:
Implementation Notes
~~~~~~~~~~~~~~~~~~~~
- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use BartTokenizer.encode to get the proper splitting.
- The forward pass of ``BartModel`` will create decoder inputs (using the helper function ``transformers.modeling_bart._prepare_bart_decoder_inputs``) if they are not passed. This is different than some other modeling APIs.
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to ``fairseq.encode`` starts with a space.
- ``BartForConditionalGeneration.generate`` should be used for conditional generation tasks like summarization; see the example in that docstring.
- Models that load the ``"facebook/bart-large-cnn"`` weights will not have a ``mask_token_id``, or be able to perform mask-filling tasks.
- For training/forward passes that don't involve beam search, pass ``use_cache=False``.
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForConditionalGeneration
:members: generate, forward
BartConfig
~~~~~~~~~~~~~~~~~~~~~
@ -45,11 +55,7 @@ MBartTokenizer
.. autoclass:: transformers.MBartTokenizer
:members: build_inputs_with_special_tokens, prepare_translation_batch
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForConditionalGeneration
:members: generate, forward
BartModel
~~~~~~~~~~~~~

View File

@ -0,0 +1,111 @@
Pegasus
----------------------------------------------------
**DISCLAIMER:** If you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=sshleifer&labels=&template=bug-report.md&title>`__ and assign
@sshleifer.
Overview
~~~~~~~~~~~~~~~~~~~~~
The Pegasus model was `proposed <https://arxiv.org/abs/1912.08777>`_ by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu on Dec 18, 2019.
According to the abstract,
- Pegasus' pretraining task is intentionally similar to summarization: important sentences are removed/masked from an input document and are generated together as one output sequence from the remaining sentences, similar to an extractive summary.
- Pegasus achieves SOTA summarization performance on all 12 downstream tasks, as measured by ROUGE and human eval.
The Authors' code can be found `here <https://github.com/google-research/pegasus>`_
Checkpoints
~~~~~~~~~~~
All `checkpoints <https://huggingface.co/models?search=pegasus>`_ are fine-tuned for summarization, besides ``pegasus-large``, from which the other checkpoints are fine-tuned.
- Each checkpoint is 2.2 GB on disk and 568M parameters.
- FP16 is not supported (help/ideas on this appreciated!).
- Summarizing xsum in fp32 takes about 400ms/sample, with default parameters on a v100 GPU.
- For XSUM, the paper reports ROUGE-1/ROUGE-2/ROUGE-L of 47.21/24.56/39.25. As of Aug 9, this port scores 46.91/24.34/39.1.
The gap is likely because of different alpha/length_penalty implementations in beam search.
Implementation Notes
~~~~~~~~~~~~~~~~~~~~
- All models are transformer encoder-decoders with 16 layers in each component.
- The implementation is completely inherited from ``BartForConditionalGeneration``
- Some key configuration differences:
- static, sinusoidal position embeddings
- no ``layernorm_embedding`` (``PegasusConfig.normalize_embedding=False``)
- the model starts generating with ``pad_token_id`` (whose token embedding is all zeros) as the decoder prefix.
- ``num_beams=8``
- All pretrained pegasus checkpoints are the same besides three attributes: ``tokenizer.model_max_length`` (maximum input size), ``max_length`` (maximum number of tokens to generate) and ``length_penalty`` (see the snippet after this list).
- Code to convert checkpoints trained in the author's `repo <https://github.com/google-research/pegasus>`_ can be found in ``convert_pegasus_tf_to_pytorch.py``
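A minimal sketch (assuming the hosted config and tokenizer files can be downloaded with ``from_pretrained``) for inspecting those three attributes on a given checkpoint:
.. code-block:: python

from transformers import PegasusConfig, PegasusTokenizer
# the three attributes that differ across pegasus checkpoints
cfg = PegasusConfig.from_pretrained('google/pegasus-xsum')
tok = PegasusTokenizer.from_pretrained('google/pegasus-xsum')
print(tok.model_max_length, cfg.max_length, cfg.length_penalty)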
Usage Example
~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
src_text = [
""" PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]
model_name = 'google/pegasus-xsum'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest').to(torch_device)
translated = model.generate(**batch)
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
assert tgt_text[0] == "California's largest electricity provider has turned off power to tens of thousands of customers."
PegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This class inherits all functionality from ``BartForConditionalGeneration``; see that page for method signatures.
Available models are listed at `Model List <https://huggingface.co/models?search=pegasus>`__
PegasusConfig
~~~~~~~~~~~~~~~~~~~
This config fully inherits from ``BartConfig``, but Pegasus uses different default values.
Up-to-date parameter values can be seen in `S3 <https://s3.amazonaws.com/models.huggingface.co/bert/google/pegasus-xsum/config.json>`_.
As of Aug 10, 2020, they are:
.. code-block:: python
dict(
vocab_size=96103,
max_position_embeddings=512,
d_model=1024,
encoder_ffn_dim=4096,
decoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_attention_heads=16,
encoder_layers=16,
decoder_layers=16,
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.1,
pad_token_id=0,
eos_token_id=1,
is_encoder_decoder=True,
normalize_before=True,
scale_embedding=True,
normalize_embedding=False,
add_final_layer_norm=True,
static_position_embeddings=True,
num_beams=8,
activation_function="relu",
)
PegasusTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
warning: ``add_tokens`` does not work at the moment.
.. autoclass:: transformers.PegasusTokenizer
:members: __call__, prepare_seq2seq_batch
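A short usage sketch (assuming the pretrained tokenizer files are available); when ``tgt_texts`` is passed, the returned batch also contains decoder inputs, as documented in ``prepare_seq2seq_batch``:
.. code-block:: python

from transformers import PegasusTokenizer
tok = PegasusTokenizer.from_pretrained('google/pegasus-xsum')
# without tgt_texts, only input_ids and attention_mask are returned
batch = tok.prepare_seq2seq_batch(src_texts=["a long document"], tgt_texts=["a short summary"])
print(sorted(batch.keys()))  # adds decoder_attention_mask and decoder_input_ids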

View File

@ -353,6 +353,8 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Pegasus | ``google/pegasus-{dataset}`` | | 16-layer, 1024-hidden, 16-heads, ~568M parameter, 2.2 GB for summary. `model list <https://huggingface.co/models?search=pegasus>`__ |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Longformer | ``allenai/longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
| | | | Starting from RoBERTa-base checkpoint, trained on documents of max length 4,096 |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+

View File

@ -413,6 +413,18 @@ def get_layers_to_copy(n_to_get, tot):
12: all_layers,
}
return layers_to_copy[n_to_get]
elif tot == 16:
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 8],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
16: all_layers,
}
return layers_to_copy[n_to_get]
else:
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
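# Illustrative sketch (not part of this diff): the selected teacher layer indices can be
# used to initialize a smaller student, assuming BART-style encoder/decoder layers stored
# in an nn.ModuleList. The helper name below is hypothetical.
import torch.nn as nn

def copy_selected_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, indices):
    assert len(indices) == len(student_layers)
    for student_idx, teacher_idx in enumerate(indices):
        # copy weights from the chosen teacher layer into the corresponding student layer
        student_layers[student_idx].load_state_dict(teacher_layers[teacher_idx].state_dict())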

View File

@ -0,0 +1,14 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
# From appendix C of paper https://arxiv.org/abs/1912.08777
# Set --gradient_accumulation_steps so that effective batch size is 256 (2*128, 4*64, 8*32, 16*16)
python finetune.py \
--learning_rate=1e-4 \
--do_train \
--do_predict \
--n_val 1000 \
--val_check_interval 0.25 \
--max_source_length 512 --max_target_length 56 \
--freeze_embeds --label_smoothing 0.1 \
"$@"

View File

@ -37,6 +37,7 @@ from .configuration_marian import MarianConfig
from .configuration_mmbt import MMBTConfig
from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -150,6 +151,7 @@ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_longformer import LongformerTokenizer, LongformerTokenizerFast
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@ -287,6 +289,7 @@ if is_torch_available():
XLMForMultipleChoice,
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_bart import (
PretrainedBartModel,
BartForSequenceClassification,

View File

@ -32,6 +32,7 @@ from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
from .configuration_marian import MarianConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_pegasus import PegasusConfig
from .configuration_reformer import ReformerConfig
from .configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -81,6 +82,7 @@ CONFIG_MAPPING = OrderedDict(
("albert", AlbertConfig,),
("camembert", CamembertConfig,),
("xlm-roberta", XLMRobertaConfig,),
("pegasus", PegasusConfig),
("marian", MarianConfig,),
("mbart", MBartConfig,),
("bart", BartConfig,),

View File

@ -18,6 +18,7 @@
import logging
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_callable
logger = logging.getLogger(__name__)
@ -31,8 +32,73 @@ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
"yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}
BART_CONFIG_ARGS_DOC = r"""
Args:
vocab_size (:obj:`int`, optional, defaults to 50265):
Vocabulary size of the model. Defines the number of different tokens that can be represented by the :obj:`inputs_ids` passed to the forward method.
d_model (:obj:`int`, optional, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (:obj:`int`, optional, defaults to 12):
Number of encoder layers, 16 for pegasus, 6 for bart-base and marian
decoder_layers (:obj:`int`, optional, defaults to 12):
Number of decoder layers, 16 for pegasus, 6 for bart-base and marian
encoder_attention_heads (:obj:`int`, optional, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (:obj:`int`, optional, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
encoder_ffn_dim (:obj:`int`, optional, defaults to 4096):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in encoder.
activation_function (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
The non-linear activation function (function or string) in the encoder and pooler.
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
dropout (:obj:`float`, optional, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, optional, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (:obj:`float`, optional, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (:obj:`float`, optional, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (:obj:`int`, optional, defaults to 1024):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
init_std (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
add_bias_logits (:obj:`bool`, optional, defaults to False):
True for marian only.
normalize_before (:obj:`bool`, optional, defaults to False):
Call layernorm before attention ops. True for pegasus, mbart. False for bart. FIXME: marian?
normalize_embedding (:obj:`bool`, optional, defaults to True):
Call layernorm after embeddings. Only True for Bart.
static_position_embeddings (:obj:`bool`, optional, defaults to False):
Don't learn positional embeddings, use sinusoidal. True for marian, pegasus.
add_final_layer_norm (:obj:`bool`, optional, defaults to False):
Whether to add a final layernorm to the encoder and decoder outputs. True for pegasus and mbart, False for bart.
scale_embedding (:obj:`bool`, optional, defaults to False):
Scale embeddings by multiplying them by sqrt(d_model).
eos_token_id (:obj:`int`, optional, defaults to 2):
End of stream token id.
pad_token_id (:obj:`int`, optional, defaults to 1):
Padding token id.
bos_token_id (:obj:`int`, optional, defaults to 0):
Beginning of stream token id.
encoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper <https://arxiv.org/abs/1909.11556>`__ for more details.
decoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper <https://arxiv.org/abs/1909.11556>`__ for more details.
extra_pos_embeddings: (:obj:`int`, optional, defaults to 2):
How many extra learned positional embeddings to use. Should be pad_token_id+1 for bart.
num_labels: (:obj:`int`, optional, defaults to 2):
The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
is_encoder_decoder (:obj:`bool`, optional, defaults to True):
Whether this is an encoder/decoder model.
"""
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
class BartConfig(PretrainedConfig):
r"""
Configuration class for Bart. Parameters are renamed from the fairseq implementation
@ -42,7 +108,7 @@ class BartConfig(PretrainedConfig):
def __init__(
self,
activation_dropout=0.0,
extra_pos_embeddings=2,
extra_pos_embeddings=2, # FIXME(@sshleifer): delete?
activation_function="gelu",
vocab_size=50265,
d_model=1024,
@ -81,6 +147,7 @@ class BartConfig(PretrainedConfig):
>>> config = BartConfig.from_pretrained('facebook/bart-large')
>>> model = BartModel(config)
"""
if "hidden_size" in common_kwargs:
raise ValueError("hidden size is called d_model")
@ -146,3 +213,4 @@ class BartConfig(PretrainedConfig):
class MBartConfig(BartConfig):
model_type = "mbart"
"""See real config values at https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json."""

View File

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PEGASUS model configuration """
import logging
from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
from .file_utils import add_start_docstrings_to_callable
logger = logging.getLogger(__name__)
DEFAULTS = dict(
vocab_size=96103,
max_position_embeddings=512,
d_model=1024,
encoder_ffn_dim=4096,
decoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_attention_heads=16,
encoder_layers=16,
decoder_layers=16,
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.1,
pad_token_id=0,
eos_token_id=1,
is_encoder_decoder=True,
normalize_before=True,
scale_embedding=True,
normalize_embedding=False,
add_final_layer_norm=True,
static_position_embeddings=True,
num_beams=8,
activation_function="relu",
)
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
class PegasusConfig(BartConfig):
r"""
:class:`~transformers.PegasusConfig` is the configuration class to store the configuration of a
`PegasusModel`.
"""
model_type = "pegasus"
# The implementation of the config object is in BartConfig
@property
def default_config_parameters(self):
return DEFAULTS
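# Note: DEFAULTS above is merged with per-dataset overrides (max_length, length_penalty)
# in convert_pegasus_tf_to_pytorch.py when building each checkpoint's config.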

View File

@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from typing import Dict
import tensorflow as tf
import torch
from tqdm import tqdm
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.configuration_pegasus import DEFAULTS
PATTERNS = [
# replace left string with right string to get the relevant state_dict key (identical state dict to bart)
["memory_attention", "encoder_attn"],
["attention", "attn"],
["/", "."],
[".LayerNorm.gamma", "_layer_norm.weight"],
[".LayerNorm.beta", "_layer_norm.bias"],
["r.layer_", "r.layers."],
["output_proj", "out_proj"],
["ffn.dense_1.", "fc2."],
["ffn.dense.", "fc1."],
["ffn_layer_norm", "final_layer_norm"],
["kernel", "weight"],
["encoder_layer_norm.", "encoder.layer_norm."],
["decoder_layer_norm.", "decoder.layer_norm."],
["embeddings.weights", "shared.weight"],
]
def rename_state_dict_key(k):
for pegasus_name, bart_name in PATTERNS:
k = k.replace(pegasus_name, bart_name)
return k
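# Illustrative trace of the sequential replacements above, on a hypothetical TF-style key:
#   "decoder/layer_5/ffn/dense/kernel"
#   -> "decoder.layer_5.ffn.dense.kernel"    ("/" -> ".")
#   -> "decoder.layers.5.ffn.dense.kernel"   ("r.layer_" -> "r.layers.")
#   -> "decoder.layers.5.fc1.kernel"         ("ffn.dense." -> "fc1.")
#   -> "decoder.layers.5.fc1.weight"         ("kernel" -> "weight")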
# See appendix C of paper for all hyperparams
max_gen_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
"large": 256, # @sshleifer chose arbitrarily
}
max_model_length = {
"xsum": 512,
"cnn_dailymail": 1024,
"newsroom": 512,
"wikihow": 512,
"multi_news": 1024,
"reddit_tifu": 512,
"big_patent": 1024,
"arxiv": 1024,
"pubmed": 1024,
"gigaword": 128,
"aeslc": 512,
"billsum": 1024,
"large": 1024,
}
expected_alpha = {
"multinews": 0.9,
"wikihow": 0.6,
"reddit_tifu": 0.6,
"big_patent": 0.7,
"gigaword": 0.6,
"aeslc": 0.6,
"billsum": 0.6,
} # otherwise 0.8
# TODO(SS): one constant
def convert_pegasus_to_bart(tf_weights: dict, cfg_updates: dict) -> PegasusForConditionalGeneration:
cfg_kwargs = DEFAULTS.copy()
cfg_kwargs.update(cfg_updates)
cfg = PegasusConfig(**cfg_kwargs)
bart = PegasusForConditionalGeneration(cfg)
sd = bart.model.state_dict()
mapping = {}
for k, v in tf_weights.items():
new_k = rename_state_dict_key(k)
if new_k not in sd:
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
if "dense" in k or "proj" in new_k:
v = v.T
mapping[new_k] = torch.tensor(v, dtype=sd[new_k].dtype)
assert v.shape == sd[new_k].shape, f"{new_k}, {k}, {v.shape}, {sd[new_k].shape}"
# make sure embedding.padding_idx is respected
mapping["shared.weight"][cfg.pad_token_id] = torch.zeros_like(mapping["shared.weight"][cfg.pad_token_id + 1])
mapping["encoder.embed_tokens.weight"] = mapping["shared.weight"]
mapping["decoder.embed_tokens.weight"] = mapping["shared.weight"]
empty_biases = {k: torch.zeros_like(v) for k, v in sd.items() if k.endswith("bias") and k not in mapping}
mapping.update(**empty_biases)
missing, extra = bart.model.load_state_dict(mapping, strict=False)
unexpected_missing = [
k for k in missing if k not in ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
]
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
assert extra == [], f"no matches found for the following tf keys {extra}"
return bart
def get_tf_weights_as_numpy(path="./ckpt/aeslc/model.ckpt-32000") -> Dict:
init_vars = tf.train.list_variables(path)
tf_weights = {}
ignore_name = ["Adafactor", "global_step"]
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
skip_key = any([pat in name for pat in ignore_name])
if skip_key:
continue
array = tf.train.load_variable(path, name)
tf_weights[name] = array
return tf_weights
def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
# save tokenizer first
dataset = Path(ckpt_path).parent.name
desired_max_model_length = max_model_length[dataset]
tok = PegasusTokenizer.from_pretrained("sshleifer/pegasus", model_max_length=desired_max_model_length)
assert tok.model_max_length == desired_max_model_length
tok.save_pretrained(save_dir)
# convert model
tf_weights = get_tf_weights_as_numpy(ckpt_path)
cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
torch_model.save_pretrained(save_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
parser.add_argument("save_dir", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
if args.save_dir is None:
args.save_dir = f"pegasus/{Path(args.tf_ckpt_path).parent.name}"
convert_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir)
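# Example invocations (paths are illustrative; the checkpoint layout follows the authors'
# repo, i.e. ./ckpt/<dataset>/model.ckpt-<step>):
#   python convert_pegasus_tf_to_pytorch.py ./ckpt/aeslc/model.ckpt-32000 ./pegasus/aeslc
#   python convert_pegasus_tf_to_pytorch.py ./ckpt/aeslc/model.ckpt-32000  # save_dir defaults to pegasus/aeslc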

View File

@ -34,6 +34,7 @@ from .configuration_auto import (
LongformerConfig,
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
@ -125,6 +126,7 @@ from .modeling_mobilebert import (
MobileBertModel,
)
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_reformer import (
ReformerForMaskedLM,
ReformerForQuestionAnswering,
@ -283,6 +285,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
(MarianConfig, MarianMTModel),
(BartConfig, BartForConditionalGeneration),
(EncoderDecoderConfig, EncoderDecoderModel),

View File

@ -19,9 +19,7 @@ from .configuration_marian import MarianConfig
from .modeling_bart import BartForConditionalGeneration
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
]
# See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
class MarianMTModel(BartForConditionalGeneration):

View File

@ -0,0 +1,46 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Pegasus model, ported from https://github.com/google-research/pegasus"""
from .configuration_pegasus import PegasusConfig
from .file_utils import add_start_docstrings
from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
class PegasusForConditionalGeneration(BartForConditionalGeneration):
config_class = PegasusConfig
r"""
Pytorch version of google's pegasus model for summarization.
Model API is identical to BartForConditionalGeneration.
Available models are listed at `Model List <https://huggingface.co/models?search=pegasus>`__
Examples::
>>> from transformers import PegasusTokenizer, PegasusForConditionalGeneration
>>> from typing import List
>>> PGE_ARTICLE = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
>>> mname = "google/pegasus-xsum"
>>> model = PegasusForConditionalGeneration.from_pretrained(mname)
>>> tok = PegasusTokenizer.from_pretrained(mname)
>>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE]) # don't need tgt_text for inference
>>> gen = model.generate(**batch) # for forward pass: model(**batch)
>>> summary: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
>>> assert summary[0] == "California's largest electricity provider has turned off power to tens of thousands of customers."
"""
# All the code is in src/transformers/modeling_bart.py

View File

@ -30,8 +30,11 @@ from .configuration_auto import (
FlaubertConfig,
GPT2Config,
LongformerConfig,
MarianConfig,
MBartConfig,
MobileBertConfig,
OpenAIGPTConfig,
PegasusConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
@ -41,8 +44,6 @@ from .configuration_auto import (
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_marian import MarianConfig
from .configuration_mobilebert import MobileBertConfig
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bart import BartTokenizer, MBartTokenizer
@ -58,6 +59,7 @@ from .tokenization_longformer import LongformerTokenizer
from .tokenization_marian import MarianTokenizer
from .tokenization_mobilebert import MobileBertTokenizer, MobileBertTokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_pegasus import PegasusTokenizer
from .tokenization_reformer import ReformerTokenizer
from .tokenization_retribert import RetriBertTokenizer, RetriBertTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
@ -79,6 +81,7 @@ TOKENIZER_MAPPING = OrderedDict(
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, None)),
(CamembertConfig, (CamembertTokenizer, None)),
(PegasusConfig, (PegasusTokenizer, None)),
(MBartConfig, (MBartTokenizer, None)),
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(MarianConfig, (MarianTokenizer, None)),

View File

@ -0,0 +1,193 @@
# coding=utf-8
# Copyright 2020 Google and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional
from transformers.tokenization_reformer import ReformerTokenizer
from .tokenization_utils_base import BatchEncoding
class PegasusTokenizer(ReformerTokenizer):
offset = 103 # entries 2-104 are only used for pretraining
vocab_files_names = {"vocab_file": "spiece.model"}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Dont use reserved words added_token_encoder, added_tokens_decoder because of
# AssertionError: Non-consecutive added token '1' found. in from_pretrained
assert len(self.added_tokens_decoder) == 0
self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
# entries 2-104 are only used for pretraining and called unk_2, ...unk_104
self.encoder.update({i: f"unk_{i}" for i in range(2, self.offset + 2)})
self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}
def _convert_token_to_id(self, token: str) -> int:
""" Converts a token (str) in an id using the vocab. """
if token in self.decoder:
return self.decoder[token]
elif token in self.added_tokens_decoder:
return self.added_tokens_decoder[token]
sp_id = self.sp_model.piece_to_id(token)
return sp_id + self.offset
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.encoder:
return self.encoder[index]
elif index in self.added_tokens_encoder:
return self.added_tokens_encoder[index]
else:
# assert index > self.offset, f"cannot decode ids between 2 and {self.offset}. Got {index}"
token = self.sp_model.IdToPiece(index - self.offset)
return token
@property
def vocab_size(self) -> int:
return len(self.sp_model) + self.offset
def get_vocab(self) -> Dict[str, int]:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def num_special_tokens_to_add(self, pair=False):
"""Just EOS"""
return 1
def _special_token_mask(self, seq):
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
assert all_special_ids == set([0, 1])
return [1 if x in all_special_ids else 0 for x in seq]
def get_special_tokens_mask(
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""Get list where entries are [1] if a token is [eos] or [pad] else 0."""
if already_has_special_tokens:
return self._special_token_mask(token_ids_0)
elif token_ids_1 is None:
return self._special_token_mask(token_ids_0) + [1]
else:
return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""
Build model inputs from a sequence by adding eos to the end. no bos token is added to the front.
- single sequence: ``X </s>``
- pair of sequences: ``A B </s>`` (not intended use)
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + [self.eos_token_id]
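# Example (eos_token_id is 1 for pegasus): build_inputs_with_special_tokens([5, 6, 7]) -> [5, 6, 7, 1]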
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
return_tensors: str = "pt",
truncation=True,
padding="longest",
) -> BatchEncoding:
"""
Prepare model inputs for summarization or translation.
Arguments:
src_texts: (:obj:`list`):
list of documents to summarize or source language texts
tgt_texts: (:obj:`list`, `optional`):
list of tgt language texts or summaries.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts)
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries)
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
Return:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if "" in src_texts:
raise ValueError(f"found empty string in src_texts: {src_texts}")
tokenizer_kwargs = dict(
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
truncation=truncation,
padding=padding,
)
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
if tgt_texts is None:
return model_inputs
if max_target_length is not None:
tokenizer_kwargs["max_length"] = max_target_length
decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v
return model_inputs

View File

@ -45,6 +45,7 @@ if is_torch_available():
_prepare_bart_decoder_inputs,
SinusoidalPositionalEmbedding,
)
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
@require_torch
@ -478,7 +479,6 @@ class BartModelIntegrationTests(unittest.TestCase):
self.assertFalse(model.config.is_valid_mbart())
tok = BartTokenizer.from_pretrained("facebook/bart-large")
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
dct = tok.batch_encode_plus(
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",

View File

@ -23,14 +23,13 @@ RO_CODE = 250020
@require_torch
class AbstractMBartIntegrationTest(unittest.TestCase):
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
maxDiff = 1000 # longer string compare tracebacks
checkpoint_name = None
@classmethod
def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
cls.pad_token_id = 1
return cls
@cached_property
@ -43,7 +42,7 @@ class AbstractMBartIntegrationTest(unittest.TestCase):
@require_torch
class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name = "facebook/mbart-large-en-ro"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
@ -73,7 +72,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
]
),
}
net_input["attention_mask"] = net_input["input_ids"].ne(self.pad_token_id)
net_input["attention_mask"] = net_input["input_ids"].ne(1)
with torch.no_grad():
logits, *other_stuff = model(**net_input)
@ -125,7 +124,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
@require_torch
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name = "facebook/mbart-large-cc25"
src_text = [
" UN Chief Says There Is No Military Solution in Syria",

View File

@ -0,0 +1,79 @@
import unittest
from transformers import AutoConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bart import PGE_ARTICLE
from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
if is_torch_available():
from transformers import AutoModelForSeq2SeqLM
XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """
@require_torch
class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name = "google/pegasus-xsum"
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
tgt_text = [
"California's largest electricity provider has turned off power to tens of thousands of customers.",
"N-Dubz have revealed they weren't expecting to get four nominations at this year's Mobo Awards.",
]
@cached_property
def model(self):
return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name).to(torch_device)
@slow
def test_pegasus_xsum_summary(self):
assert self.tokenizer.model_max_length == 512
inputs = self.tokenizer(self.src_text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(
torch_device
)
assert inputs.input_ids.shape == (2, 421)
translated_tokens = self.model.generate(**inputs)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text, decoded)
if "cuda" not in torch_device:
return
# Demonstrate fp16 issue, Contributions welcome!
self.model.half()
translated_tokens_fp16 = self.model.generate(**inputs, max_length=10)
decoded = self.tokenizer.batch_decode(translated_tokens_fp16, skip_special_tokens=True)
bad_fp16_result = ["unk_7unk_7unk_7unk_7unk_7unk_7unk_7", "unk_7unk_7unk_7unk_7unk_7unk_7unk_7"]
self.assertListEqual(decoded, bad_fp16_result)
class PegasusConfigTests(unittest.TestCase):
def test_all_config_max_lengths(self):
expected_max_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
}
failures = []
pegasus_prefix = "google/pegasus"
for dataset, max_len in expected_max_length.items():
mname = f"{pegasus_prefix}-{dataset}"
cfg = AutoConfig.from_pretrained(mname)
if cfg.max_length != max_len:
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
if failures == []:
return
# error
all_fails = "\n".join(failures)
raise AssertionError(f"The following configs have unexpected settings: {all_fails}")

View File

@ -0,0 +1,69 @@
import unittest
from pathlib import Path
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from transformers.tokenization_pegasus import PegasusTokenizer
from .test_tokenization_common import TokenizerTesterMixin
class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = PegasusTokenizer
def setUp(self):
super().setUp()
save_dir = Path(self.tmpdirname)
spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
if not (save_dir / spm_file).exists():
tokenizer = self.pegasus_large_tokenizer
tokenizer.save_pretrained(self.tmpdirname)
@cached_property
def pegasus_large_tokenizer(self):
return PegasusTokenizer.from_pretrained("google/pegasus-large")
@unittest.skip("add_tokens does not work yet")
def test_swap_special_token(self):
pass
def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
if not kwargs:
return self.pegasus_large_tokenizer
else:
return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
return ("This is a test", "This is a test")
def test_pegasus_large_tokenizer_settings(self):
tokenizer = self.pegasus_large_tokenizer
# The tracebacks for the following asserts are **better** without messages or self.assertEqual
assert tokenizer.vocab_size == 96103
assert tokenizer.pad_token_id == 0
assert tokenizer.eos_token_id == 1
assert tokenizer.offset == 103
assert tokenizer.unk_token_id == tokenizer.offset + 2 == 105
assert tokenizer.unk_token == "<unk>"
assert tokenizer.mask_token is None
assert tokenizer.mask_token_id is None
assert tokenizer.model_max_length == 1024
raw_input_str = "To ensure a smooth flow of bank resolutions."
desired_result = [413, 615, 114, 2291, 1971, 113, 1679, 10710, 107, 1]
ids = tokenizer([raw_input_str], return_tensors=None).input_ids[0]
self.assertListEqual(desired_result, ids)
assert tokenizer.convert_ids_to_tokens([0, 1, 2]) == ["<pad>", "</s>", "unk_2"]
@require_torch
def test_pegasus_large_seq2seq_truncation(self):
src_texts = ["This is going to be way too long" * 10000, "short example"]
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
assert batch.input_ids.shape == (2, 1024)
assert batch.attention_mask.shape == (2, 1024)
assert "decoder_input_ids" in batch # because tgt_texts was specified
assert batch.decoder_input_ids.shape == (2, 5)
assert batch.decoder_attention_mask.shape == (2, 5)
assert len(batch) == 4 # no extra keys