mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
New BartModel (#2745)
* Results same as fairseq * Wrote a ton of tests * Struggled with api signatures * added some docs
This commit is contained in:
parent
564fd75d65
commit
53ce3854a1
@ -99,4 +99,5 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
|
||||
model_doc/camembert
|
||||
model_doc/albert
|
||||
model_doc/xlmroberta
|
||||
model_doc/flaubert
|
||||
model_doc/flaubert
|
||||
model_doc/bart
|
||||
|
52
docs/source/model_doc/bart.rst
Normal file
52
docs/source/model_doc/bart.rst
Normal file
@ -0,0 +1,52 @@
|
||||
Bart
|
||||
----------------------------------------------------
|
||||
**DISCLAIMER:** This model is still a work in progress, if you see something strange,
|
||||
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
|
||||
@sshleifer
|
||||
|
||||
The Bart model was `proposed <https://arxiv.org/abs/1910.13461>`_ by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov, Luke Zettlemoyer on 29 Oct, 2019.
|
||||
It is a sequence to sequence model where both encoder and decoder are transformers. The paper also introduces a novel pretraining objective, and demonstrates excellent summarization results.
|
||||
The authors released their code `here <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_
|
||||
|
||||
**Abstract:**
|
||||
|
||||
*We present BART, a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text. It uses a standard Tranformer-based neural machine translation architecture which, despite its simplicity, can be seen as generalizing BERT (due to the bidirectional encoder), GPT (with the left-to-right decoder), and many other more recent pretraining schemes. We evaluate a number of noising approaches, finding the best performance by both randomly shuffling the order of the original sentences and using a novel in-filling scheme, where spans of text are replaced with a single mask token. BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains of up to 6 ROUGE. BART also provides a 1.1 BLEU increase over a back-translation system for machine translation, with only target language pretraining. We also report ablation experiments that replicate other pretraining schemes within the BART framework, to better measure which factors most influence end-task performance.*
|
||||
`BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension`
|
||||
|
||||
|
||||
Notes:
|
||||
- Bart doesn't use :obj:`token_type_ids`, for sequence classification just use BartTokenizer.encode to get the proper splitting.
|
||||
- Inputs to the decoder are created by BartModel.forward if they are not passed. This is different than some other model APIs.
|
||||
- Model predictions are intended to be identical to the original implementation. This only works, however, if the string you pass to fairseq.encode starts with a space.
|
||||
|
||||
BartModel
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BartModel
|
||||
:members: forward
|
||||
|
||||
|
||||
BartForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BartForMaskedLM
|
||||
:members: forward
|
||||
|
||||
|
||||
BartForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BartForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
BartConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BartConfig
|
||||
:members:
|
||||
|
||||
Automatic Creation of Decoder Inputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
This is enabled by default
|
||||
|
||||
.. autofunction:: transformers.modeling_bart._prepare_bart_decoder_inputs
|
@ -275,6 +275,13 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
|
||||
| | | | FlauBERT large architecture |
|
||||
| | | (see `details <https://github.com/getalp/Flaubert>`__) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| Bart | ``bart-large`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters |
|
||||
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
|
||||
| | | | bart-large base architecture with a classification head |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
|
||||
|
||||
.. <https://huggingface.co/transformers/examples.html>`__
|
||||
|
@ -303,7 +303,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
|
||||
self.drop = nn.Dropout(dropout)
|
||||
mask = self._get_attn_subsequent_mask(MAX_SIZE)
|
||||
# Register self.mask as a buffer in TransformerDecoderLayer, so
|
||||
# Register self.mask as a saved_state in TransformerDecoderLayer, so
|
||||
# it gets TransformerDecoderLayer's cuda behavior automatically.
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
|
@ -21,6 +21,7 @@ import logging
|
||||
|
||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||
from .configuration_bart import BartConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
@ -106,6 +107,7 @@ from .pipelines import (
|
||||
)
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
from .tokenization_bart import BartTokenizer
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
@ -204,6 +206,7 @@ if is_torch_available():
|
||||
XLMForQuestionAnsweringSimple,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_bart import BartForSequenceClassification, BartModel, BartForMaskedLM
|
||||
from .modeling_roberta import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaModel,
|
||||
|
@ -19,6 +19,7 @@ import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
@ -42,6 +43,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -67,6 +69,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
("albert", AlbertConfig,),
|
||||
("camembert", CamembertConfig,),
|
||||
("xlm-roberta", XLMRobertaConfig,),
|
||||
("bart", BartConfig,),
|
||||
("roberta", RobertaConfig,),
|
||||
("flaubert", FlaubertConfig,),
|
||||
("bert", BertConfig,),
|
||||
|
101
src/transformers/configuration_bart.py
Normal file
101
src/transformers/configuration_bart.py
Normal file
@ -0,0 +1,101 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BART configuration """
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_bart_large_url = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json"
|
||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"bart-large": _bart_large_url,
|
||||
"bart-large-mnli": _bart_large_url, # fine as same
|
||||
"bart-cnn": None, # not done
|
||||
}
|
||||
|
||||
|
||||
class BartConfig(PretrainedConfig):
|
||||
r"""
|
||||
Configuration class for Bart. Parameters are renamed from the fairseq implementation
|
||||
"""
|
||||
model_type = "bart"
|
||||
pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_dropout=0.0,
|
||||
vocab_size=50265,
|
||||
pad_token_id=1,
|
||||
eos_token_id=2,
|
||||
d_model=1024,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_layers=12,
|
||||
encoder_attention_heads=16,
|
||||
decoder_ffn_dim=4096,
|
||||
decoder_layers=12,
|
||||
decoder_attention_heads=16,
|
||||
encoder_layerdrop=0.0,
|
||||
decoder_layerdrop=0.0,
|
||||
attention_dropout=0.0,
|
||||
dropout=0.1,
|
||||
max_position_embeddings=1024,
|
||||
init_std=0.02,
|
||||
classifier_dropout=0.0,
|
||||
output_past=False,
|
||||
num_labels=3,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
:class:`~transformers.BartConfig` is the configuration class for `BartModel`.
|
||||
Examples:
|
||||
config = BartConfig.from_pretrained('bart-large')
|
||||
model = BartModel(config)
|
||||
"""
|
||||
super().__init__(num_labels=num_labels, output_past=output_past, pad_token_id=pad_token_id, **common_kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = self.num_hidden_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.init_std = init_std # Normal(0, this parameter)
|
||||
|
||||
# 3 Types of Dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.dropout = dropout
|
||||
|
||||
# Classifier stuff
|
||||
self.classif_dropout = classifier_dropout
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.d_model
|
@ -0,0 +1,100 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BART checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from transformers import BartConfig, BartForSequenceClassification, BartModel, BartTokenizer
|
||||
|
||||
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
IGNORE_KEYS = ["encoder.version", "decoder.version", "model.encoder.version", "model.decoder.version"]
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
b2 = torch.hub.load("pytorch/fairseq", checkpoint_path)
|
||||
b2.eval() # disable dropout
|
||||
b2.model.upgrade_state_dict(b2.model.state_dict())
|
||||
config = BartConfig()
|
||||
tokens = b2.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained("bart-large").encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
assert torch.eq(tokens, tokens2).all()
|
||||
|
||||
# assert their_output.size() == (1, 11, 1024)
|
||||
|
||||
if checkpoint_path == "bart.large":
|
||||
state_dict = b2.model.state_dict()
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
model = BartModel(config)
|
||||
their_output = b2.extract_features(tokens)
|
||||
|
||||
else: # MNLI Case
|
||||
state_dict = b2.state_dict()
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
state_dict.pop("_float_tensor", None)
|
||||
model = BartForSequenceClassification(config)
|
||||
their_output = b2.predict("mnli", tokens, return_logits=True)
|
||||
for k in IGNORE_KEYS:
|
||||
state_dict.pop(k, None)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
our_outputs = model.forward(tokens)[0]
|
||||
|
||||
assert their_output.shape == our_outputs.shape
|
||||
assert (their_output == our_outputs).all().item()
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("fairseq_path", choices=["bart.large", "bart.large.mnli"], type=str, help="")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(
|
||||
args.fairseq_path, args.pytorch_dump_folder_path,
|
||||
)
|
@ -21,6 +21,7 @@ from collections import OrderedDict
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
@ -43,6 +44,7 @@ from .modeling_albert import (
|
||||
AlbertForSequenceClassification,
|
||||
AlbertModel,
|
||||
)
|
||||
from .modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP, BartForMaskedLM, BartForSequenceClassification, BartModel
|
||||
from .modeling_bert import (
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BertForMaskedLM,
|
||||
@ -118,6 +120,7 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@ -142,6 +145,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(AlbertConfig, AlbertModel),
|
||||
(CamembertConfig, CamembertModel),
|
||||
(XLMRobertaConfig, XLMRobertaModel),
|
||||
(BartConfig, BartModel),
|
||||
(RobertaConfig, RobertaModel),
|
||||
(BertConfig, BertModel),
|
||||
(OpenAIGPTConfig, OpenAIGPTModel),
|
||||
@ -161,6 +165,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(AlbertConfig, AlbertForMaskedLM),
|
||||
(CamembertConfig, CamembertForMaskedLM),
|
||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||
(BartConfig, BartForMaskedLM),
|
||||
(RobertaConfig, RobertaForMaskedLM),
|
||||
(BertConfig, BertForPreTraining),
|
||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||
@ -180,6 +185,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(AlbertConfig, AlbertForMaskedLM),
|
||||
(CamembertConfig, CamembertForMaskedLM),
|
||||
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||
(BartConfig, BartForMaskedLM),
|
||||
(RobertaConfig, RobertaForMaskedLM),
|
||||
(BertConfig, BertForMaskedLM),
|
||||
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||
@ -198,6 +204,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(AlbertConfig, AlbertForSequenceClassification),
|
||||
(CamembertConfig, CamembertForSequenceClassification),
|
||||
(XLMRobertaConfig, XLMRobertaForSequenceClassification),
|
||||
(BartConfig, BartForSequenceClassification),
|
||||
(RobertaConfig, RobertaForSequenceClassification),
|
||||
(BertConfig, BertForSequenceClassification),
|
||||
(XLNetConfig, XLNetForSequenceClassification),
|
||||
|
1028
src/transformers/modeling_bart.py
Normal file
1028
src/transformers/modeling_bart.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -236,42 +236,6 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
@staticmethod
|
||||
def prepare_model_kwargs(**kwargs):
|
||||
""" Prepare the encoder and decoder's keyword arguments.
|
||||
|
||||
Keyword arguments come in 3 flavors:
|
||||
- encoder-specific (prefixed by `encoder_`)
|
||||
- decoder-specific (prefixed by `decoder_`)
|
||||
- those that apply to the model as whole.
|
||||
|
||||
We let the specific kwargs override the common ones in case of
|
||||
conflict.
|
||||
"""
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
decoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs.update(
|
||||
{
|
||||
argument[len("encoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
)
|
||||
decoder_kwargs.update(
|
||||
{
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
)
|
||||
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
|
||||
return encoder_kwargs, decoder_kwargs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedEncoderDecoder):
|
||||
r"""
|
||||
|
@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
|
||||
from .modeling_utils import create_position_ids_from_input_ids
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -56,7 +57,7 @@ class RobertaEmbeddings(BertEmbeddings):
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = self.create_position_ids_from_input_ids(input_ids).to(input_ids.device)
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
@ -64,18 +65,6 @@ class RobertaEmbeddings(BertEmbeddings):
|
||||
input_ids, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
def create_position_ids_from_input_ids(self, x):
|
||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||
`utils.make_positions`.
|
||||
|
||||
:param torch.Tensor x:
|
||||
:return torch.Tensor:
|
||||
"""
|
||||
mask = x.ne(self.padding_idx).long()
|
||||
incremental_indicies = torch.cumsum(mask, dim=1) * mask
|
||||
return incremental_indicies + self.padding_idx
|
||||
|
||||
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
||||
""" We are provided embeddings directly. We cannot infer which are padded so just generate
|
||||
sequential position ids.
|
||||
|
@ -1448,6 +1448,20 @@ class SequenceSummary(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||
`utils.make_positions`.
|
||||
|
||||
:param torch.Tensor x:
|
||||
:return torch.Tensor:
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indicies = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
||||
return incremental_indicies.long() + padding_idx
|
||||
|
||||
|
||||
def prune_linear_layer(layer, index, dim=0):
|
||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
|
@ -21,6 +21,7 @@ from collections import OrderedDict
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
AutoConfig,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
@ -37,6 +38,7 @@ from .configuration_auto import (
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_bart import BartTokenizer
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
@ -63,6 +65,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
||||
(AlbertConfig, (AlbertTokenizer, None)),
|
||||
(CamembertConfig, (CamembertTokenizer, None)),
|
||||
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
|
||||
(BartConfig, (BartTokenizer, None)),
|
||||
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
|
||||
(BertConfig, (BertTokenizer, BertTokenizerFast)),
|
||||
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
|
||||
|
35
src/transformers/tokenization_bart.py
Normal file
35
src/transformers/tokenization_bart.py
Normal file
@ -0,0 +1,35 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
|
||||
|
||||
# vocab and merges same as roberta
|
||||
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
|
||||
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
|
||||
_all_bart_models = [
|
||||
"bart-large",
|
||||
"bart-large-mnli",
|
||||
# "bart-large-cnn"
|
||||
]
|
||||
|
||||
|
||||
class BartTokenizer(RobertaTokenizer):
|
||||
# merges and vocab same as Roberta
|
||||
max_model_input_sizes = {m: 1024 for m in _all_bart_models}
|
||||
pretrained_vocab_files_map = {
|
||||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
47
src/transformers/utils_encoder_decoder.py
Normal file
47
src/transformers/utils_encoder_decoder.py
Normal file
@ -0,0 +1,47 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Classes to support Encoder-Decoder architectures """
|
||||
|
||||
|
||||
def prepare_encoder_decoder_model_kwargs(**kwargs):
|
||||
""" Prepare the encoder and decoder's keyword arguments.
|
||||
|
||||
Keyword arguments come in 3 flavors:
|
||||
- encoder-specific (prefixed by `encoder_`)
|
||||
- decoder-specific (prefixed by `decoder_`)
|
||||
- those that apply to the model as whole.
|
||||
|
||||
We let the specific kwargs override the common ones in case of
|
||||
conflict.
|
||||
"""
|
||||
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
|
||||
}
|
||||
if "input_ids" in kwargs_common:
|
||||
kwargs["encoder_input_ids"] = kwargs_common.pop("input_ids")
|
||||
|
||||
decoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs = kwargs_common.copy()
|
||||
encoder_kwargs.update(
|
||||
{argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")}
|
||||
)
|
||||
decoder_kwargs.update(
|
||||
{argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")}
|
||||
)
|
||||
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
|
||||
return encoder_kwargs, decoder_kwargs
|
343
tests/test_modeling_bart.py
Normal file
343
tests/test_modeling_bart.py
Normal file
@ -0,0 +1,343 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Huggingface
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BartModel,
|
||||
BartForMaskedLM,
|
||||
BartForSequenceClassification,
|
||||
BartConfig,
|
||||
)
|
||||
from transformers.modeling_bart import (
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
shift_tokens_right,
|
||||
_prepare_bart_decoder_inputs,
|
||||
)
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelTester:
|
||||
def __init__(
|
||||
self, parent,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
self.seq_length = 7
|
||||
self.is_training = True
|
||||
self.use_labels = False
|
||||
self.vocab_size = 99
|
||||
self.hidden_size = 32
|
||||
self.num_hidden_layers = 5
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 37
|
||||
self.hidden_act = "gelu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 12
|
||||
torch.manual_seed(0)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3,)
|
||||
input_ids[:, -1] = 2 # Eos Token
|
||||
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
encoder_ffn_dim=self.intermediate_size,
|
||||
decoder_ffn_dim=self.intermediate_size,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
)
|
||||
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
def prepare_bart_inputs_dict(
|
||||
config, input_ids, attention_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BartModel, BartForMaskedLM, BartForSequenceClassification) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
# TODO(SS): fix the below in a separate PR
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_head_masking = False
|
||||
test_resize_embeddings = False # This requires inputs_dict['input_ids']
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_advanced_inputs(self):
|
||||
# (config, input_ids, token_type_ids, input_mask, *unused) = \
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(config, inputs_dict["input_ids"])
|
||||
model = BartModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
# test init
|
||||
self.assertTrue((model.encoder.embed_tokens.weight == model.shared.weight).all().item())
|
||||
|
||||
def _check_var(module):
|
||||
"""Check that we initialized various parameters from N(0, config.init_std)."""
|
||||
self.assertAlmostEqual(torch.std(module.weight).item(), config.init_std, 2)
|
||||
|
||||
_check_var(model.encoder.embed_tokens)
|
||||
_check_var(model.encoder.layers[0].self_attn.k_proj)
|
||||
_check_var(model.encoder.layers[0].fc1)
|
||||
_check_var(model.encoder.embed_positions)
|
||||
|
||||
decoder_features_with_created_mask = model.forward(**inputs_dict)[0]
|
||||
decoder_features_with_passed_mask = model.forward(
|
||||
decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
|
||||
useless_mask = torch.zeros_like(decoder_attn_mask)
|
||||
decoder_features = model.forward(decoder_attention_mask=useless_mask, **inputs_dict)[0]
|
||||
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
|
||||
self.assertEqual(
|
||||
decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
|
||||
)
|
||||
if decoder_attn_mask.min().item() < -1e3: # some tokens were masked
|
||||
self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
|
||||
|
||||
# Test different encoder attention masks
|
||||
decoder_features_with_long_encoder_mask = model.forward(
|
||||
inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
|
||||
|
||||
def test_save_load_strict(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
@unittest.skip("Passing inputs_embeds not implemented for Bart.")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartHeadTests(unittest.TestCase):
|
||||
|
||||
vocab_size = 99
|
||||
|
||||
def test_lm_forward(self):
|
||||
input_ids = torch.Tensor(
|
||||
[
|
||||
[71, 82, 18, 33, 46, 91, 2],
|
||||
[68, 34, 26, 58, 30, 82, 2],
|
||||
[5, 97, 17, 39, 94, 40, 2],
|
||||
[76, 83, 94, 25, 70, 78, 2],
|
||||
[87, 59, 41, 35, 48, 66, 2],
|
||||
[55, 13, 16, 58, 5, 2, 1], # note padding
|
||||
[64, 27, 31, 51, 12, 75, 2],
|
||||
[52, 64, 86, 17, 83, 39, 2],
|
||||
[48, 61, 9, 24, 71, 82, 2],
|
||||
[26, 1, 60, 48, 22, 13, 2],
|
||||
[21, 5, 62, 28, 14, 76, 2],
|
||||
[45, 98, 37, 86, 59, 48, 2],
|
||||
[70, 70, 50, 9, 28, 0, 2],
|
||||
]
|
||||
).long()
|
||||
batch_size = input_ids.shape[0]
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
)
|
||||
model = BartForSequenceClassification(config)
|
||||
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||
logits = outputs[0]
|
||||
expected_shape = torch.Size((batch_size, config.num_labels))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
lm_model = BartForMaskedLM(config)
|
||||
loss, logits, enc_features = lm_model.forward(
|
||||
input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
|
||||
)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
def test_lm_uneven_forward(self):
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
|
||||
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
def test_generate(self):
|
||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
encoder_layers=2,
|
||||
decoder_layers=2,
|
||||
encoder_attention_heads=2,
|
||||
decoder_attention_heads=2,
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
output_past=True,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
new_input_ids = lm_model.generate(input_ids)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
|
||||
|
||||
def test_shift_tokens_right(self):
|
||||
input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long()
|
||||
shifted = shift_tokens_right(input_ids, 1)
|
||||
n_pad_before = input_ids.eq(1).float().sum()
|
||||
n_pad_after = shifted.eq(1).float().sum()
|
||||
self.assertEqual(shifted.shape, input_ids.shape)
|
||||
self.assertEqual(n_pad_after, n_pad_before - 1)
|
||||
self.assertTrue(torch.eq(shifted[:, 0], 2).all())
|
||||
|
||||
@slow
|
||||
def test_tokenization(self):
|
||||
tokenizer = BartTokenizer.from_pretrained("bart-large")
|
||||
examples = [" Hello world", " DomDramg"] # need leading spaces for equality
|
||||
fairseq_results = [
|
||||
torch.Tensor([0, 20920, 232, 2]),
|
||||
torch.Tensor([0, 11349, 495, 4040, 571, 2]),
|
||||
]
|
||||
for ex, desired_result in zip(examples, fairseq_results):
|
||||
bart_toks = tokenizer.encode(ex, return_tensors="pt")
|
||||
_assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
if a is None and b is None:
|
||||
return True
|
||||
try:
|
||||
if torch.allclose(a, b, atol=atol):
|
||||
return True
|
||||
raise
|
||||
except Exception:
|
||||
msg = "{} != {}".format(a, b)
|
||||
if prefix:
|
||||
msg = prefix + ": " + msg
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
@require_torch
|
||||
class BartModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = BartModel.from_pretrained("bart-large")
|
||||
input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]).long()
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||
with torch.no_grad():
|
||||
output = model.forward(**inputs_dict)[0]
|
||||
expected_shape = torch.Size((1, 11, 1024))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.Tensor(
|
||||
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
|
||||
@slow
|
||||
def test_mnli_inference(self):
|
||||
|
||||
example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1]
|
||||
input_ids = torch.Tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b]).long()
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli") # eval called in from_pre
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
|
||||
# Test that model hasn't changed
|
||||
with torch.no_grad():
|
||||
batched_logits, features = model.forward(**inputs_dict)
|
||||
expected_shape = torch.Size((2, 3))
|
||||
self.assertEqual(batched_logits.shape, expected_shape)
|
||||
expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]])
|
||||
logits_arr = batched_logits[0].detach()
|
||||
|
||||
# Test that padding does not change results
|
||||
input_ids_no_pad = torch.Tensor([example_b[:-1]]).long()
|
||||
|
||||
inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
|
||||
with torch.no_grad():
|
||||
logits2 = model.forward(**inputs_dict)[0]
|
||||
_assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE)
|
||||
_assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)
|
||||
|
||||
@unittest.skip("This is just too slow")
|
||||
def test_model_from_pretrained(self):
|
||||
# Forces 1.6GB download from S3 for each model
|
||||
for model_name in list(BART_PRETRAINED_MODEL_ARCHIVE_MAP.keys()):
|
||||
model = BartModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
@ -142,10 +142,17 @@ class ModelTesterMixin:
|
||||
out_len = len(outputs)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.assertEqual(out_len % 2, 0)
|
||||
decoder_attentions = outputs[(out_len // 2) - 1]
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
correct_outlen = (
|
||||
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
|
||||
)
|
||||
decoder_attention_idx = 1
|
||||
if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first
|
||||
correct_outlen += 1 # compute loss
|
||||
decoder_attention_idx += 1
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
decoder_attentions = outputs[decoder_attention_idx]
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
@ -562,15 +569,16 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs_dict["encoder_input_ids"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
|
||||
del inputs_dict["encoder_input_ids"]
|
||||
del inputs_dict["decoder_input_ids"]
|
||||
inputs_dict.pop("decoder_input_ids", None)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
@ -34,6 +34,7 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.modeling_roberta import RobertaEmbeddings, RobertaForMultipleChoice, RobertaForQuestionAnswering
|
||||
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from transformers.modeling_utils import create_position_ids_from_input_ids
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -291,7 +292,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
[[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
|
||||
)
|
||||
|
||||
position_ids = model.create_position_ids_from_input_ids(input_ids)
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
|
||||
self.assertEqual(position_ids.shape, expected_positions.shape)
|
||||
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
|
||||
|
||||
|
@ -164,7 +164,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
decoder_lm_labels=decoder_lm_labels,
|
||||
)
|
||||
loss, prediction_scores = outputs[0], outputs[1]
|
||||
loss, prediction_scores, encoder_features = outputs
|
||||
self.parent.assertEqual(len(outputs), 3)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
|
Loading…
Reference in New Issue
Block a user