diff --git a/docs/source/index.rst b/docs/source/index.rst index f9ff1a0606c..215e6cba6c5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,4 +99,5 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train model_doc/camembert model_doc/albert model_doc/xlmroberta - model_doc/flaubert \ No newline at end of file + model_doc/flaubert + model_doc/bart diff --git a/docs/source/model_doc/bart.rst b/docs/source/model_doc/bart.rst new file mode 100644 index 00000000000..a034f3b57a2 --- /dev/null +++ b/docs/source/model_doc/bart.rst @@ -0,0 +1,52 @@ +Bart +---------------------------------------------------- +**DISCLAIMER:** This model is still a work in progress, if you see something strange, +file a `Github Issue `__ and assign +@sshleifer + +The Bart model was `proposed `_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer on 29 Oct, 2019. +It is a sequence to sequence model where both encoder and decoder are transformers. The paper also introduces a novel pretraining objective, and demonstrates excellent summarization results. +The authors released their code `here `_ + +**Abstract:** + +*We present BART, a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text. It uses a standard Tranformer-based neural machine translation architecture which, despite its simplicity, can be seen as generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder), and many other more recent pretraining schemes. We evaluate a number of noising approaches, finding the best performance by both randomly shuffling the order of the original sentences and using a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides a 1.1 BLEU increase over a back-translation system for machine translation, with only target language pretraining. We also report ablation experiments that replicate other pretraining schemes within the BART framework, to better measure which factors most influence end-task performance.* +`BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension` + + +Notes: +- Bart doesn't use :obj:`token_type_ids`, for sequence classification just use BartTokenizer.encode to get the proper splitting. +- Inputs to the decoder are created by BartModel.forward if they are not passed. This is different than some other model APIs. +- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to fairseq.encode starts with a space. + +BartModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BartModel + :members: forward + + +BartForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BartForMaskedLM + :members: forward + + +BartForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BartForSequenceClassification + :members: forward + +BartConfig +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BartConfig + :members: + +Automatic Creation of Decoder Inputs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +This is enabled by default + +.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index d708054f41f..4120f88dc16 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -275,6 +275,13 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | | | FlauBERT large architecture | | | | (see `details `__) | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| Bart | ``bart-large`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters | +| | | (see `details `_) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters | +| | | | bart-large base architecture with a classification head | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ .. `__ diff --git a/examples/summarization/modeling_bertabs.py b/examples/summarization/modeling_bertabs.py index bad412baac1..0691403186c 100644 --- a/examples/summarization/modeling_bertabs.py +++ b/examples/summarization/modeling_bertabs.py @@ -303,7 +303,7 @@ class TransformerDecoderLayer(nn.Module): self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) self.drop = nn.Dropout(dropout) mask = self._get_attn_subsequent_mask(MAX_SIZE) - # Register self.mask as a buffer in TransformerDecoderLayer, so + # Register self.mask as a saved_state in TransformerDecoderLayer, so # it gets TransformerDecoderLayer's cuda behavior automatically. self.register_buffer("mask", mask) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c5b36917e08..93bd94c6227 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -21,6 +21,7 @@ import logging from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig +from .configuration_bart import BartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -106,6 +107,7 @@ from .pipelines import ( ) from .tokenization_albert import AlbertTokenizer from .tokenization_auto import AutoTokenizer +from .tokenization_bart import BartTokenizer from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .tokenization_camembert import CamembertTokenizer @@ -204,6 +206,7 @@ if is_torch_available(): XLMForQuestionAnsweringSimple, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, ) + from .modeling_bart import BartForSequenceClassification, BartModel, BartForMaskedLM from .modeling_roberta import ( RobertaForMaskedLM, RobertaModel, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 4fd23fee260..3b112704cca 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -19,6 +19,7 @@ import logging from collections import OrderedDict from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig +from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig @@ -42,6 +43,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( (key, value) for pretrained_map in [ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + BART_PRETRAINED_CONFIG_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -67,6 +69,7 @@ CONFIG_MAPPING = OrderedDict( ("albert", AlbertConfig,), ("camembert", CamembertConfig,), ("xlm-roberta", XLMRobertaConfig,), + ("bart", BartConfig,), ("roberta", RobertaConfig,), ("flaubert", FlaubertConfig,), ("bert", BertConfig,), diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py new file mode 100644 index 00000000000..2e096c5501c --- /dev/null +++ b/src/transformers/configuration_bart.py @@ -0,0 +1,101 @@ +# coding=utf-8 +# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BART configuration """ + + +import logging + +from .configuration_utils import PretrainedConfig + + +logger = logging.getLogger(__name__) + +_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json" +BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "bart-large": _bart_large_url, + "bart-large-mnli": _bart_large_url, # fine as same + "bart-cnn": None, # not done +} + + +class BartConfig(PretrainedConfig): + r""" + Configuration class for Bart. Parameters are renamed from the fairseq implementation + """ + model_type = "bart" + pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__( + self, + activation_dropout=0.0, + vocab_size=50265, + pad_token_id=1, + eos_token_id=2, + d_model=1024, + encoder_ffn_dim=4096, + encoder_layers=12, + encoder_attention_heads=16, + decoder_ffn_dim=4096, + decoder_layers=12, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + attention_dropout=0.0, + dropout=0.1, + max_position_embeddings=1024, + init_std=0.02, + classifier_dropout=0.0, + output_past=False, + num_labels=3, + **common_kwargs + ): + r""" + :class:`~transformers.BartConfig` is the configuration class for `BartModel`. + Examples: + config = BartConfig.from_pretrained('bart-large') + model = BartModel(config) + """ + super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs) + + self.vocab_size = vocab_size + self.d_model = d_model # encoder_embed_dim and decoder_embed_dim + self.eos_token_id = eos_token_id + + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = self.num_hidden_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.max_position_embeddings = max_position_embeddings + self.init_std = init_std # Normal(0, this parameter) + + # 3 Types of Dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.dropout = dropout + + # Classifier stuff + self.classif_dropout = classifier_dropout + + @property + def num_attention_heads(self): + return self.encoder_attention_heads + + @property + def hidden_size(self): + return self.d_model diff --git a/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..6a9403aea4e --- /dev/null +++ b/src/transformers/convert_bart_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert BART checkpoint.""" + + +import argparse +import logging +from pathlib import Path + +import fairseq +import torch +from packaging import version + +from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer + + +if version.parse(fairseq.__version__) < version.parse("0.9.0"): + raise Exception("requires fairseq >= 0.9.0") + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +SAMPLE_TEXT = "Hello world! cécé herlolip" + +rename_keys = [ + ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"), + ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"), + ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"), + ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"), +] +IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our BERT structure. + """ + b2 = torch.hub.load("pytorch/fairseq", checkpoint_path) + b2.eval() # disable dropout + b2.model.upgrade_state_dict(b2.model.state_dict()) + config = BartConfig() + tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0) + tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0) + assert torch.eq(tokens, tokens2).all() + + # assert their_output.size() == (1, 11, 1024) + + if checkpoint_path == "bart.large": + state_dict = b2.model.state_dict() + state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] + model = BartModel(config) + their_output = b2.extract_features(tokens) + + else: # MNLI Case + state_dict = b2.state_dict() + state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"] + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + state_dict.pop("_float_tensor", None) + model = BartForSequenceClassification(config) + their_output = b2.predict("mnli", tokens, return_logits=True) + for k in IGNORE_KEYS: + state_dict.pop(k, None) + model.load_state_dict(state_dict) + model.eval() + our_outputs = model.forward(tokens)[0] + + assert their_output.shape == our_outputs.shape + assert (their_output == our_outputs).all().item() + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="") + parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_bart_checkpoint( + args.fairseq_path, args.pytorch_dump_folder_path, + ) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index fbc8bc03ad3..ae7d88d5a32 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -21,6 +21,7 @@ from collections import OrderedDict from .configuration_auto import ( AlbertConfig, AutoConfig, + BartConfig, BertConfig, CamembertConfig, CTRLConfig, @@ -43,6 +44,7 @@ from .modeling_albert import ( AlbertForSequenceClassification, AlbertModel, ) +from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel from .modeling_bert import ( BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertForMaskedLM, @@ -118,6 +120,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( (key, value) for pretrained_map in [ BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + BART_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, @@ -142,6 +145,7 @@ MODEL_MAPPING = OrderedDict( (AlbertConfig, AlbertModel), (CamembertConfig, CamembertModel), (XLMRobertaConfig, XLMRobertaModel), + (BartConfig, BartModel), (RobertaConfig, RobertaModel), (BertConfig, BertModel), (OpenAIGPTConfig, OpenAIGPTModel), @@ -161,6 +165,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( (AlbertConfig, AlbertForMaskedLM), (CamembertConfig, CamembertForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM), + (BartConfig, BartForMaskedLM), (RobertaConfig, RobertaForMaskedLM), (BertConfig, BertForPreTraining), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), @@ -180,6 +185,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( (AlbertConfig, AlbertForMaskedLM), (CamembertConfig, CamembertForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM), + (BartConfig, BartForMaskedLM), (RobertaConfig, RobertaForMaskedLM), (BertConfig, BertForMaskedLM), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), @@ -198,6 +204,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (AlbertConfig, AlbertForSequenceClassification), (CamembertConfig, CamembertForSequenceClassification), (XLMRobertaConfig, XLMRobertaForSequenceClassification), + (BartConfig, BartForSequenceClassification), (RobertaConfig, RobertaForSequenceClassification), (BertConfig, BertForSequenceClassification), (XLNetConfig, XLNetForSequenceClassification), diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py new file mode 100644 index 00000000000..3ec76cf675b --- /dev/null +++ b/src/transformers/modeling_bart.py @@ -0,0 +1,1028 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model, ported from the fairseq repo.""" + +import logging +import random +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .configuration_bart import BartConfig +from .file_utils import add_start_docstrings, add_start_docstrings_to_callable +from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids + + +logger = logging.getLogger(__name__) + + +BART_PRETRAINED_MODEL_ARCHIVE_MAP = { + "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/pytorch_model.bin", + "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/pytorch_model.bin", +} + +BART_START_DOCSTRING = r""" + + This model is a PyTorch `torch.nn.Module `_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matters related to general usage and behavior. + + Parameters: + config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. + +""" + +BART_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Use BartTokenizer.encode to produce them. + Padding will be ignored by default should you provide it. + Indices can be obtained using :class:`transformers.BartTokenizer.encode(text)`. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Mask to avoid performing attention on padding token indices in input_ids. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): + Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper. + decoder_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, 1, tgt_seq_len, tgt_seq_len)`, `optional`, defaults to :obj:`None`): + Default behavior: generate a tensor that ignores pad tokens and future tokens, as in the paper. + If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify. + See diagram 1 in the paper for more info on the default strategy +""" +LARGE_NEGATIVE = -1e4 + + +def _prepare_bart_decoder_inputs( + config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, +): + """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if + none are provided. This mimics the default behavior in fairseq. To override it pass in masks. + """ + pad_token_id = config.pad_token_id + need_causal_mask = not config.output_past + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right(input_ids, pad_token_id) + bsz, tgt_len = decoder_input_ids.size()[:2] + if decoder_attn_mask is None: + decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id) + if need_causal_mask: + causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1) + else: + causal_lm_mask = None + new_shape = (bsz, tgt_len, tgt_len) + # make it broadcastable so can just be added to the attention coefficients + decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape) + assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len) + return decoder_input_ids, decoder_attn_mask + + +class PretrainedBartModel(PreTrainedModel): + config_class = BartConfig + base_model_prefix = "model" + pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP + + def _init_weights(self, module): + std = self.config.init_std + + # called init_bert_params in fairseq + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = 1 + input_ids = torch.Tensor( + [ + [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], + [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token], + ] + ).long() + decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs( + self.config, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attn_mask=None + ) + dummy_inputs = { + "decoder_input_ids": decoder_input_ids, + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + "decoder_attention_mask": decoder_attn_mask, + } + return dummy_inputs + + +def _make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data # .T + return lin_layer + + +# Helper Functions, mostly for making masks +def _check_shapes(shape_1, shape2): + if shape_1 != shape2: + raise AssertionError("shape mismatch: {} != {}".format(shape_1, shape2)) + + +def _combine_masks(key_padding_mask, attn_mask, targ_size): + # targ_size = (bsz, tgt_len, src_len) + a = torch.zeros(targ_size) + b = torch.zeros(targ_size) + if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size + _check_shapes(key_padding_mask.shape, targ_size[:2]) + reshaped = key_padding_mask.unsqueeze(2).expand(*targ_size) + a[reshaped] = 1e-8 + + if attn_mask is not None: # (tgt_len, src_len) -> targ_size + _check_shapes(attn_mask.shape, targ_size[-2:]) + b = attn_mask.unsqueeze(0).expand(*targ_size) + return (a + b).unsqueeze(1).clamp(LARGE_NEGATIVE,) + + +def shift_tokens_right(input_ids, pad_token_id): + """Shift input ids one token to the right, and wrap the last non pad token (usually ).""" + prev_output_tokens = input_ids.clone() + index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = input_ids[:, :-1] + return prev_output_tokens + + +def make_padding_mask(input_ids, padding_idx=1): + """True for pad tokens""" + padding_mask = input_ids.eq(padding_idx) + if not padding_mask.any(): + padding_mask = None + return padding_mask + + +# Helper Modules + + +class EncoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + self.output_attentions = config.output_attentions + self.self_attn = SelfAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = F.gelu + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def forward(self, x, encoder_padding_mask): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, src_len)` where padding elements are indicated by ``1``. + for t_tgt, t_src is excluded (or masked out), =0 means it is + included in attention + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + residual = x + x, attn_weights = self.self_attn.forward( + query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.final_layer_norm(x) + return x, attn_weights + + +class BartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer + is a :class:`EncoderLayer`. + + Args: + config: BartConfig + """ + + def __init__(self, config: BartConfig, embed_tokens): + super().__init__() + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = embed_tokens + + self.embed_positions = LearnedPositionalEmbedding(config.max_position_embeddings, embed_dim, self.padding_idx,) + self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = LayerNorm(embed_dim) + + def forward( + self, input_ids=None, attention_mask=None, + ): + """ + Args: + input_ids (LongTensor): tokens in the source language of shape + `(batch, src_len)` + attention_mask (torch.LongTensor): indicating which indices are padding tokens. + Returns: + namedtuple: + - **x** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + - **all_attentions** (List[Tensor]): Attention weights for each layer. + During training might not be of length n_layers because of layer dropout. + """ + inputs_embeds = self.embed_tokens(input_ids) + embed_pos = self.embed_positions(input_ids) + x = inputs_embeds + embed_pos + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + encoder_states, all_attentions = [], [] + + # encoder layers + for encoder_layer in self.layers: + + if self.output_hidden_states: + encoder_states.append(x) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + attn = None + else: + x, attn = encoder_layer.forward(x, attention_mask) + + if self.output_attentions: + all_attentions.append(attn) + + if self.output_hidden_states: + encoder_states.append(x) + + encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states] + + return x, encoder_states, all_attentions + + +class DecoderLayer(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = SelfAttention( + embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = F.gelu + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.encoder_attn = SelfAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + encoder_decoder_attention=True, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def forward( + self, + x, + encoder_hidden_states, + encoder_attn_mask=None, + decoder_cached_states=None, + attention_mask=None, + need_attn_weights=False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attn_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if decoder_cached_states is None: + prev_self_attn_state, prev_attn_state = (None, None) + else: + assert len(decoder_cached_states) == 3 + prev_self_attn_state, prev_attn_state = ( + decoder_cached_states["self"], + decoder_cached_states["encoder_decoder"], + ) + + residual = x + if prev_self_attn_state is not None: + saved_state = prev_self_attn_state + decoder_cached_states["self"] = saved_state + y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it + + x, self_attn_weights = self.self_attn.forward( + query=x, + key=y, + value=y, + decoder_cached_states=decoder_cached_states, + need_weights=need_attn_weights, + attn_mask=attention_mask, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.self_attn_layer_norm(x) + residual = x + assert self.encoder_attn.cache_key != self.self_attn.cache_key + if prev_attn_state is not None: + saved_state = prev_attn_state + decoder_cached_states["encoder_decoder"] = saved_state + x, encoder_attn_weights = self.encoder_attn.forward( + query=x, + key=encoder_hidden_states, # could be None + value=encoder_hidden_states, + key_padding_mask=encoder_attn_mask, + decoder_cached_states=decoder_cached_states, + static_kv=True, + need_weights=False, # not returning it so why compute it + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + + x = self.encoder_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.final_layer_norm(x) + return ( + x, + self_attn_weights, + decoder_cached_states, + ) # just self_attn weights for now, following t5, decoder_cached_states = cache for decoding + + def _past_to_dict(self, prev_attn_state): + prev_key, prev_value = prev_attn_state[:2] + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + return saved_state + + +class BartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer + is a :class:`DecoderLayer`. + Args: + config: BartConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: BartConfig, embed_tokens: nn.Embedding): + super().__init__() + self.output_past = config.output_past + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = config.max_position_embeddings + self.embed_tokens = embed_tokens + self.embed_positions = LearnedPositionalEmbedding( + config.max_position_embeddings, config.d_model, self.padding_idx, + ) + self.layers = nn.ModuleList( + [DecoderLayer(config) for _ in range(config.decoder_layers)] + ) # type: List[DecoderLayer] + self.layernorm_embedding = LayerNorm(config.d_model) + + def forward( + self, + input_ids, + encoder_hidden_states, + encoder_padding_mask, + combined_mask, + decoder_cached_states=None, + **unused + ): + """ + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + input_ids (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_hidden_states: output from the encoder, used for + encoder-side attention + encoder_padding_mask: for ignoring pad tokens + decoder_cached_states (dict or None): dictionary used for storing state during generation + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - hidden states + - attentions + """ + # embed positions + positions = self.embed_positions(input_ids) + + if decoder_cached_states is not None: + input_ids = input_ids[:, -1:] + positions = positions[:, -1:] + x = self.embed_tokens(input_ids) + + if positions is not None: + x += positions + + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = x.transpose(0, 1) # (seq_len, BS, model_dim) + # decoder layers + all_hidden_states = () + all_self_attns = () + next_decoder_cache = [] + + for i, decoder_layer in enumerate(self.layers): + decoder_layer # type: DecoderLayer + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability > self.layerdrop): + continue + layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None + x, layer_self_attn, layer_past = decoder_layer.forward( + x, + encoder_hidden_states, + encoder_padding_mask, + decoder_cached_states=layer_state, + attention_mask=combined_mask, + need_attn_weights=self.output_attentions, + ) + if self.output_past: + next_decoder_cache.append(layer_past) + if self.output_hidden_states: + all_hidden_states += (x,) + if self.output_attentions: + all_self_attns += (layer_self_attn,) + + # Convert shapes from (seq_len, BS, model_dim) to (BS, seq_len, model_dim) + all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states] + x = x.transpose(0, 1) + + return x, next_decoder_cache, all_hidden_states, list(all_self_attns) + + +class SelfAttention(nn.Module): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + encoder_decoder_attention=False, # otherwise self_attention + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.encoder_decoder_attention = encoder_decoder_attention + qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim # True for all BART + + assert self.encoder_decoder_attention or qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self" + + def _shape(self, tensor, dim_0, bsz): + return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + decoder_cached_states: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = False, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time(SeqLen) x Batch x Channel + + Args: + + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + """ + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + # get here for encoder decoder cause of static_kv + if decoder_cached_states is not None: # get the last k,v and mask for reuse + saved_state = decoder_cached_states.get(self.cache_key, {}) + if "prev_key" in saved_state: + # previous time steps are cached - no need to recompute key and value if they are static + if static_kv: + assert self.encoder_decoder_attention + key = value = None + else: + saved_state = None + + q = self.q_proj(query) * self.scaling + if self.encoder_decoder_attention: + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + else: + k = self.k_proj(query) + v = self.v_proj(query) + + q = self._shape(q, tgt_len, bsz) + if k is not None: + k = self._shape(k, -1, bsz) + if v is not None: + v = self._shape(v, -1, bsz) + + if saved_state is not None: + k, v, key_padding_mask, new_state = self._use_and_update_saved_state( + k, v, saved_state, key_padding_mask, static_kv, bsz + ) + saved_state.update( + { + "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim), + "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim), + "prev_key_padding_mask": key_padding_mask, + } + ) + decoder_cached_states[self.cache_key] = saved_state # Update cache + assert k is not None + src_len = k.size(1) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) + + if attn_mask is not None: + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # This is part of a workaround to get around fork/join parallelism not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len) + + if key_padding_mask is not None: # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool) + attn_weights = attn_weights.masked_fill(reshaped, float("-inf")) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,) + assert v is not None + attn_output = torch.bmm(attn_probs, v) + assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = self.out_proj(attn_output) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + return attn_output, attn_weights + + def _use_and_update_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz): + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + assert k is not None and v is not None + prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None) # type: Optional[Tensor] + key_padding_mask = self._cat_prev_key_padding_mask( + key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv + ) + return k, v, key_padding_mask, saved_state + + @staticmethod + def _cat_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current is None + elif prev_key_padding_mask is not None: + + filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1)) + if prev_key_padding_mask.is_cuda: + filler = filler.cuda() + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1) + elif key_padding_mask is not None: + filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)) + if key_padding_mask.is_cuda: + filler = filler.cuda() + new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + +class BartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + # This can trivially be shared with RobertaClassificationHead + + def __init__( + self, input_dim, inner_dim, num_classes, pooler_dropout, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, x): + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class LearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + Padding ids are ignored by either offsetting based on padding_idx + or by setting padding_idx to None and ensuring that the appropriate + position ids are passed to the forward function. + """ + + def __init__( + self, num_embeddings: int, embedding_dim: int, padding_idx: int, + ): + # if padding_idx is specified then offset the embedding ids by + # this index and adjust num_embeddings appropriately + assert padding_idx is not None + num_embeddings += padding_idx + 1 # WHY? + super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + + def forward(self, input): + """Input is expected to be of size [bsz x seqlen].""" + positions = create_position_ids_from_input_ids(input, self.padding_idx) + return super().forward(positions) + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): + if torch.cuda.is_available(): + try: + from apex.normalization import FusedLayerNorm + + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + except ImportError: + pass + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a input_ids with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _filter_out_falsey_values(tup) -> Tuple: + """Remove entries that are None or [] from an iterable.""" + return tuple(x for x in tup if isinstance(x, torch.Tensor) or x) + + +RET_DOCSTRING = r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. +""" +# Public API + + +@add_start_docstrings( + "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING, +) +class BartModel(PretrainedBartModel): + def __init__(self, config: BartConfig): + super().__init__(config) + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BartEncoder(config, self.shared) + self.decoder = BartDecoder(config, self.shared) + + self.init_weights() + + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) + def forward( + self, + input_ids, + attention_mask=None, + decoder_input_ids=None, + encoder_outputs=None, # type: Tuple + decoder_attention_mask=None, + decoder_cached_states=None, + ): + if attention_mask is not None: + assert attention_mask.dim() == 2 + + attention_mask = (1.0 - attention_mask.long()) * -10000.0 + assert attention_mask.max() <= 0 + + # make masks if user doesn't supply + decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs( + self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask, + ) + + assert decoder_input_ids is not None + if encoder_outputs is None: + # TODO(SS): make this caching more usable when overwrite generate + encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask) + assert isinstance(encoder_outputs, tuple) + # dec_features, decoder_cached_states, dec_hidden, dec_attn + decoder_outputs = self.decoder.forward( + decoder_input_ids, + encoder_outputs[0], + attention_mask, + decoder_attn_mask, + decoder_cached_states=decoder_cached_states, + ) + # Attention and hidden_states will be [] or None if they aren't needed + decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple + assert isinstance(decoder_outputs[0], torch.Tensor) + encoder_outputs = _filter_out_falsey_values(encoder_outputs) # type: tuple + return decoder_outputs + encoder_outputs + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + + def get_output_embeddings(self): + return _make_linear_from_emb(self.shared) + + +@add_start_docstrings( + "The bare BART Model with a language modeling head", BART_START_DOCSTRING, +) +class BartForMaskedLM(PretrainedBartModel): + base_model_prefix = "model" + + def __init__(self, config: BartConfig): + super().__init__(config) + self.model = BartModel(config) + self.lm_head = _make_linear_from_emb(self.model.shared) + + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) + def forward( + self, + input_ids, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + decoder_cached_states=None, + lm_labels=None, + **unused + ): + r""" + masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + Labels for computing the masked language modeling loss. + Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring). + Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens + with labels + in ``[0, ..., config.vocab_size]``. + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + tokenizer = BartTokenizer.from_pretrained('bart-large') + model = BartForMaskedLM.from_pretrained('bart-large') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + outputs = model(input_ids=input_ids, lm_labels=input_ids) + loss, prediction_scores = outputs[:2] + """ + outputs = self.model.forward( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + decoder_cached_states=decoder_cached_states, + ) + lm_logits = self.lm_head.forward(outputs[0]) + outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here + if lm_labels is not None: + loss_fct = nn.CrossEntropyLoss() + # TODO(SS): do we need to ignore pad tokens in lm_labels? + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), lm_labels.view(-1)) + outputs = (masked_lm_loss,) + outputs + + return outputs + + @staticmethod + def prepare_inputs_for_generation(input_ids, past, **kwargs): + return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids} + + def get_output_embeddings(self): + return self.lm_head + + +@add_start_docstrings( + """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, + BART_START_DOCSTRING, +) +class BartForSequenceClassification(PretrainedBartModel): + def __init__(self, config: BartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BartModel(config) + self.classification_head = BartClassificationHead( + config.d_model, config.d_model, config.num_labels, config.classif_dropout, + ) + self.model._init_weights(self.classification_head.dense) + self.model._init_weights(self.classification_head.out_proj) + + @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING) + def forward( + self, + input_ids, + attention_mask=None, + encoder_outputs=None, + decoder_input_ids=None, + decoder_attention_mask=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): + Classification loss (cross entropy) + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention + heads. + + Examples:: + + from transformers import BartTokenizer, BartForSequenceClassification + import torch + + tokenizer = BartTokenizer.from_pretrained('bart-large') + model = BartForSequenceClassification.from_pretrained('bart-large') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", + add_special_tokens=True)).unsqueeze(0) # Batch size 1 + labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, logits = outputs[:2] + + """ + outputs = self.model.forward( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + x = outputs[0] # last hidden state + eos_mask = input_ids.eq(self.config.eos_token_id) + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] + logits = self.classification_head(sentence_representation) + # Prepend logits + outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here + if labels is not None: # prepend loss to output, + loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 95d56cf6acb..7f1a71f2f24 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -236,42 +236,6 @@ class PreTrainedEncoderDecoder(nn.Module): return decoder_outputs + encoder_outputs - @staticmethod - def prepare_model_kwargs(**kwargs): - """ Prepare the encoder and decoder's keyword arguments. - - Keyword arguments come in 3 flavors: - - encoder-specific (prefixed by `encoder_`) - - decoder-specific (prefixed by `decoder_`) - - those that apply to the model as whole. - - We let the specific kwargs override the common ones in case of - conflict. - """ - kwargs_common = { - argument: value - for argument, value in kwargs.items() - if not argument.startswith("encoder_") and not argument.startswith("decoder_") - } - decoder_kwargs = kwargs_common.copy() - encoder_kwargs = kwargs_common.copy() - encoder_kwargs.update( - { - argument[len("encoder_") :]: value - for argument, value in kwargs.items() - if argument.startswith("encoder_") - } - ) - decoder_kwargs.update( - { - argument[len("decoder_") :]: value - for argument, value in kwargs.items() - if argument.startswith("decoder_") - } - ) - decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None) - return encoder_kwargs, decoder_kwargs - class Model2Model(PreTrainedEncoderDecoder): r""" diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 50de77b85c1..e86e6864ece 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss, MSELoss from .configuration_roberta import RobertaConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu +from .modeling_utils import create_position_ids_from_input_ids logger = logging.getLogger(__name__) @@ -56,7 +57,7 @@ class RobertaEmbeddings(BertEmbeddings): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = self.create_position_ids_from_input_ids(input_ids).to(input_ids.device) + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device) else: position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) @@ -64,18 +65,6 @@ class RobertaEmbeddings(BertEmbeddings): input_ids, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds ) - def create_position_ids_from_input_ids(self, x): - """ Replace non-padding symbols with their position numbers. Position numbers begin at - padding_idx+1. Padding symbols are ignored. This is modified from fairseq's - `utils.make_positions`. - - :param torch.Tensor x: - :return torch.Tensor: - """ - mask = x.ne(self.padding_idx).long() - incremental_indicies = torch.cumsum(mask, dim=1) * mask - return incremental_indicies + self.padding_idx - def create_position_ids_from_inputs_embeds(self, inputs_embeds): """ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e272be89eca..3d2b7d6db41 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1448,6 +1448,20 @@ class SequenceSummary(nn.Module): return output +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ Replace non-padding symbols with their position numbers. Position numbers begin at + padding_idx+1. Padding symbols are ignored. This is modified from fairseq's + `utils.make_positions`. + + :param torch.Tensor x: + :return torch.Tensor: + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indicies = torch.cumsum(mask, dim=1).type_as(mask) * mask + return incremental_indicies.long() + padding_idx + + def prune_linear_layer(layer, index, dim=0): """ Prune a linear layer (a model parameters) to keep only entries in index. Return the pruned layer as a new layer with requires_grad=True. diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index f708e557288..9d1c85d5905 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -21,6 +21,7 @@ from collections import OrderedDict from .configuration_auto import ( AlbertConfig, AutoConfig, + BartConfig, BertConfig, CamembertConfig, CTRLConfig, @@ -37,6 +38,7 @@ from .configuration_auto import ( ) from .configuration_utils import PretrainedConfig from .tokenization_albert import AlbertTokenizer +from .tokenization_bart import BartTokenizer from .tokenization_bert import BertTokenizer, BertTokenizerFast from .tokenization_bert_japanese import BertJapaneseTokenizer from .tokenization_camembert import CamembertTokenizer @@ -63,6 +65,7 @@ TOKENIZER_MAPPING = OrderedDict( (AlbertConfig, (AlbertTokenizer, None)), (CamembertConfig, (CamembertTokenizer, None)), (XLMRobertaConfig, (XLMRobertaTokenizer, None)), + (BartConfig, (BartTokenizer, None)), (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)), (BertConfig, (BertTokenizer, BertTokenizerFast)), (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)), diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py new file mode 100644 index 00000000000..ef2631a3527 --- /dev/null +++ b/src/transformers/tokenization_bart.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tokenization_roberta import RobertaTokenizer + + +# vocab and merges same as roberta +vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json" +merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt" +_all_bart_models = [ + "bart-large", + "bart-large-mnli", + # "bart-large-cnn" +] + + +class BartTokenizer(RobertaTokenizer): + # merges and vocab same as Roberta + max_model_input_sizes = {m: 1024 for m in _all_bart_models} + pretrained_vocab_files_map = { + "vocab_file": {m: vocab_url for m in _all_bart_models}, + "merges_file": {m: merges_url for m in _all_bart_models}, + } diff --git a/src/transformers/utils_encoder_decoder.py b/src/transformers/utils_encoder_decoder.py new file mode 100644 index 00000000000..4c32622d46e --- /dev/null +++ b/src/transformers/utils_encoder_decoder.py @@ -0,0 +1,47 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Classes to support Encoder-Decoder architectures """ + + +def prepare_encoder_decoder_model_kwargs(**kwargs): + """ Prepare the encoder and decoder's keyword arguments. + + Keyword arguments come in 3 flavors: + - encoder-specific (prefixed by `encoder_`) + - decoder-specific (prefixed by `decoder_`) + - those that apply to the model as whole. + + We let the specific kwargs override the common ones in case of + conflict. + """ + + kwargs_common = { + argument: value + for argument, value in kwargs.items() + if not argument.startswith("encoder_") and not argument.startswith("decoder_") + } + if "input_ids" in kwargs_common: + kwargs["encoder_input_ids"] = kwargs_common.pop("input_ids") + + decoder_kwargs = kwargs_common.copy() + encoder_kwargs = kwargs_common.copy() + encoder_kwargs.update( + {argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")} + ) + decoder_kwargs.update( + {argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")} + ) + decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None) + return encoder_kwargs, decoder_kwargs diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py new file mode 100644 index 00000000000..61f60986107 --- /dev/null +++ b/tests/test_modeling_bart.py @@ -0,0 +1,343 @@ +# coding=utf-8 +# Copyright 2020 Huggingface +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tempfile +import unittest + +from transformers import is_torch_available + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor +from .utils import CACHE_DIR, require_torch, slow, torch_device + + +if is_torch_available(): + import torch + from transformers import ( + AutoModelForSequenceClassification, + BartModel, + BartForMaskedLM, + BartForSequenceClassification, + BartConfig, + ) + from transformers.modeling_bart import ( + BART_PRETRAINED_MODEL_ARCHIVE_MAP, + shift_tokens_right, + _prepare_bart_decoder_inputs, + ) + from transformers.tokenization_bart import BartTokenizer + + +@require_torch +class ModelTester: + def __init__( + self, parent, + ): + self.parent = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = True + self.use_labels = False + self.vocab_size = 99 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 12 + torch.manual_seed(0) + + def prepare_config_and_inputs_for_common(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3,) + input_ids[:, -1] = 2 # Eos Token + + config = BartConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + ) + inputs_dict = prepare_bart_inputs_dict(config, input_ids) + return config, inputs_dict + + +def prepare_bart_inputs_dict( + config, input_ids, attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + +@require_torch +class BARTModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (BartModel, BartForMaskedLM, BartForSequenceClassification) if is_torch_available() else () + is_encoder_decoder = True + # TODO(SS): fix the below in a separate PR + test_pruning = False + test_torchscript = False + test_head_masking = False + test_resize_embeddings = False # This requires inputs_dict['input_ids'] + + def setUp(self): + self.model_tester = ModelTester(self) + self.config_tester = ConfigTester(self, config_class=BartConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_advanced_inputs(self): + # (config, input_ids, token_type_ids, input_mask, *unused) = \ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"]) + model = BartModel(config) + model.to(torch_device) + model.eval() + # test init + self.assertTrue((model.encoder.embed_tokens.weight == model.shared.weight).all().item()) + + def _check_var(module): + """Check that we initialized various parameters from N(0, config.init_std).""" + self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2) + + _check_var(model.encoder.embed_tokens) + _check_var(model.encoder.layers[0].self_attn.k_proj) + _check_var(model.encoder.layers[0].fc1) + _check_var(model.encoder.embed_positions) + + decoder_features_with_created_mask = model.forward(**inputs_dict)[0] + decoder_features_with_passed_mask = model.forward( + decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict + )[0] + _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask) + useless_mask = torch.zeros_like(decoder_attn_mask) + decoder_features = model.forward(decoder_attention_mask=useless_mask, **inputs_dict)[0] + self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions + self.assertEqual( + decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model) + ) + if decoder_attn_mask.min().item() < -1e3: # some tokens were masked + self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item()) + + # Test different encoder attention masks + decoder_features_with_long_encoder_mask = model.forward( + inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long() + )[0] + _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask) + + def test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + @unittest.skip("Passing inputs_embeds not implemented for Bart.") + def test_inputs_embeds(self): + pass + + +@require_torch +class BartHeadTests(unittest.TestCase): + + vocab_size = 99 + + def test_lm_forward(self): + input_ids = torch.Tensor( + [ + [71, 82, 18, 33, 46, 91, 2], + [68, 34, 26, 58, 30, 82, 2], + [5, 97, 17, 39, 94, 40, 2], + [76, 83, 94, 25, 70, 78, 2], + [87, 59, 41, 35, 48, 66, 2], + [55, 13, 16, 58, 5, 2, 1], # note padding + [64, 27, 31, 51, 12, 75, 2], + [52, 64, 86, 17, 83, 39, 2], + [48, 61, 9, 24, 71, 82, 2], + [26, 1, 60, 48, 22, 13, 2], + [21, 5, 62, 28, 14, 76, 2], + [45, 98, 37, 86, 59, 48, 2], + [70, 70, 50, 9, 28, 0, 2], + ] + ).long() + batch_size = input_ids.shape[0] + decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size) + + config = BartConfig( + vocab_size=self.vocab_size, + d_model=24, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + max_position_embeddings=48, + ) + model = BartForSequenceClassification(config) + outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids) + logits = outputs[0] + expected_shape = torch.Size((batch_size, config.num_labels)) + self.assertEqual(logits.shape, expected_shape) + + lm_model = BartForMaskedLM(config) + loss, logits, enc_features = lm_model.forward( + input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids + ) + expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) + self.assertEqual(logits.shape, expected_shape) + self.assertIsInstance(loss.item(), float) + + def test_lm_uneven_forward(self): + config = BartConfig( + vocab_size=self.vocab_size, + d_model=24, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + max_position_embeddings=48, + ) + lm_model = BartForMaskedLM(config) + context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long() + summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long() + logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary) + expected_shape = (*summary.shape, config.vocab_size) + self.assertEqual(logits.shape, expected_shape) + + def test_generate(self): + input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long() + config = BartConfig( + vocab_size=self.vocab_size, + d_model=24, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + max_position_embeddings=48, + output_past=True, + ) + lm_model = BartForMaskedLM(config) + new_input_ids = lm_model.generate(input_ids) + self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20)) + + def test_shift_tokens_right(self): + input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long() + shifted = shift_tokens_right(input_ids, 1) + n_pad_before = input_ids.eq(1).float().sum() + n_pad_after = shifted.eq(1).float().sum() + self.assertEqual(shifted.shape, input_ids.shape) + self.assertEqual(n_pad_after, n_pad_before - 1) + self.assertTrue(torch.eq(shifted[:, 0], 2).all()) + + @slow + def test_tokenization(self): + tokenizer = BartTokenizer.from_pretrained("bart-large") + examples = [" Hello world", " DomDramg"] # need leading spaces for equality + fairseq_results = [ + torch.Tensor([0, 20920, 232, 2]), + torch.Tensor([0, 11349, 495, 4040, 571, 2]), + ] + for ex, desired_result in zip(examples, fairseq_results): + bart_toks = tokenizer.encode(ex, return_tensors="pt") + _assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex) + + +def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): + """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + msg = "{} != {}".format(a, b) + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +TOLERANCE = 1e-4 + + +@require_torch +class BartModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_no_head(self): + model = BartModel.from_pretrained("bart-large") + input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]).long() + inputs_dict = prepare_bart_inputs_dict(model.config, input_ids) + with torch.no_grad(): + output = model.forward(**inputs_dict)[0] + expected_shape = torch.Size((1, 11, 1024)) + self.assertEqual(output.shape, expected_shape) + expected_slice = torch.Tensor( + [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]] + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) + + @slow + def test_mnli_inference(self): + + example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1] + input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b]).long() + + model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli") # eval called in from_pre + inputs_dict = prepare_bart_inputs_dict(model.config, input_ids) + # Test that model hasn't changed + with torch.no_grad(): + batched_logits, features = model.forward(**inputs_dict) + expected_shape = torch.Size((2, 3)) + self.assertEqual(batched_logits.shape, expected_shape) + expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]) + logits_arr = batched_logits[0].detach() + + # Test that padding does not change results + input_ids_no_pad = torch.Tensor([example_b[:-1]]).long() + + inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad) + with torch.no_grad(): + logits2 = model.forward(**inputs_dict)[0] + _assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE) + _assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE) + + @unittest.skip("This is just too slow") + def test_model_from_pretrained(self): + # Forces 1.6GB download from S3 for each model + for model_name in list(BART_PRETRAINED_MODEL_ARCHIVE_MAP.keys()): + model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR) + self.assertIsNotNone(model) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a5003519580..c575867af56 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -142,10 +142,17 @@ class ModelTesterMixin: out_len = len(outputs) if self.is_encoder_decoder: - self.assertEqual(out_len % 2, 0) - decoder_attentions = outputs[(out_len // 2) - 1] - self.assertEqual(model.config.output_attentions, True) - self.assertEqual(model.config.output_hidden_states, False) + correct_outlen = ( + 4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions + ) + decoder_attention_idx = 1 + if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first + correct_outlen += 1 # compute loss + decoder_attention_idx += 1 + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs[decoder_attention_idx] + self.assertIsInstance(decoder_attentions, (list, tuple)) self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) self.assertListEqual( list(decoder_attentions[0].shape[-3:]), @@ -562,15 +569,16 @@ class ModelTesterMixin: # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if not self.is_encoder_decoder: input_ids = inputs_dict["input_ids"] del inputs_dict["input_ids"] else: encoder_input_ids = inputs_dict["encoder_input_ids"] - decoder_input_ids = inputs_dict["decoder_input_ids"] + decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids) del inputs_dict["encoder_input_ids"] - del inputs_dict["decoder_input_ids"] + inputs_dict.pop("decoder_input_ids", None) for model_class in self.all_model_classes: model = model_class(config) diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 28141781705..9ea25a186b1 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -34,6 +34,7 @@ if is_torch_available(): ) from transformers.modeling_roberta import RobertaEmbeddings, RobertaForMultipleChoice, RobertaForQuestionAnswering from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP + from transformers.modeling_utils import create_position_ids_from_input_ids @require_torch @@ -291,7 +292,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase): [[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]] ) - position_ids = model.create_position_ids_from_input_ids(input_ids) + position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx) self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 964d5d4afee..d62ba2bd796 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -164,7 +164,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): decoder_attention_mask=decoder_attention_mask, decoder_lm_labels=decoder_lm_labels, ) - loss, prediction_scores = outputs[0], outputs[1] + loss, prediction_scores, encoder_features = outputs + self.parent.assertEqual(len(outputs), 3) result = { "loss": loss, "prediction_scores": prediction_scores,