Blenderbot (#7418)

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

parent aee7967fc4, commit 960faaaf28
@@ -233,6 +233,7 @@ conversion utilities for the following models:
     model_doc/bart
     model_doc/bert
     model_doc/bertgeneration
+    model_doc/blenderbot
     model_doc/camembert
     model_doc/ctrl
     model_doc/deberta
75  docs/source/model_doc/blenderbot.rst  (new file)
@@ -0,0 +1,75 @@

Blenderbot
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**DISCLAIMER:** If you see something strange, file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ .

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Blender chatbot model was proposed in `Recipes for building an open-domain chatbot <https://arxiv.org/pdf/2004.13637.pdf>`__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau and Jason Weston on 30 Apr 2020.

The abstract of the paper is the following:

*Building open-domain chatbots is a challenging area for machine learning research. While prior work has shown that scaling neural models in the number of parameters and the size of the data they are trained on gives improved results, we show that other ingredients are important for a high-performing chatbot. Good conversation requires a number of skills that an expert conversationalist blends in a seamless way: providing engaging talking points and listening to their partners, and displaying knowledge, empathy and personality appropriately, while maintaining a consistent persona. We show that large scale models can learn these skills when given appropriate training data and choice of generation strategy. We build variants of these recipes with 90M, 2.7B and 9.4B parameter models, and make our models and code publicly available. Human evaluations show our best models are superior to existing approaches in multi-turn dialogue in terms of engagingness and humanness measurements. We then discuss the limitations of this work by analyzing failure cases of our models.*

The authors' code can be found `here <https://github.com/facebookresearch/ParlAI>`__ .


Implementation Notes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- Blenderbot uses a standard `seq2seq model transformer <https://arxiv.org/pdf/1706.03762.pdf>`__ based architecture.
- It inherits completely from :class:`~transformers.BartForConditionalGeneration`.
- Even though Blenderbot is one model, it uses two tokenizers: :class:`~transformers.BlenderbotSmallTokenizer` for the 90M checkpoint and :class:`~transformers.BlenderbotTokenizer` for all other checkpoints.
- The default tokenizer mapping always returns :class:`~transformers.BlenderbotSmallTokenizer`, regardless of checkpoint. To use the 3B parameter checkpoint, you must call :class:`~transformers.BlenderbotTokenizer` directly (see the sketch after the usage example below).
- Available checkpoints can be found in the `model hub <https://huggingface.co/models?search=blenderbot>`__.


Usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Model Usage:

>>> from transformers import BlenderbotSmallTokenizer, BlenderbotForConditionalGeneration
>>> mname = 'facebook/blenderbot-90M'
>>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained(mname)
>>> UTTERANCE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([UTTERANCE], return_tensors='pt')
>>> reply_ids = model.generate(**inputs)
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in reply_ids])
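For the larger checkpoints the notes above require calling :class:`~transformers.BlenderbotTokenizer` directly. A minimal sketch of that path, assuming access to the `facebook/blenderbot-3B` checkpoint (the prompt is only illustrative, not a documented example):

>>> from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
>>> mname = 'facebook/blenderbot-3B'
>>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
>>> tokenizer = BlenderbotTokenizer.from_pretrained(mname)
>>> inputs = tokenizer(["Hello, how are you?"], return_tensors='pt')
>>> reply_ids = model.generate(**inputs)
>>> print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))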
See Config Values:

>>> from transformers import BlenderbotConfig
>>> config_90 = BlenderbotConfig.from_pretrained("facebook/blenderbot-90M")
>>> config_90.to_diff_dict()  # show interesting values.
>>> configuration_3B = BlenderbotConfig("facebook/blenderbot-3B")
>>> configuration_3B.to_diff_dict()


BlenderbotConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BlenderbotConfig
    :members:


BlenderbotTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BlenderbotTokenizer
    :members: build_inputs_with_special_tokens


BlenderbotSmallTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BlenderbotSmallTokenizer
    :members:


BlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward` and `generate`.

.. autoclass:: transformers.BlenderbotForConditionalGeneration
    :members:
@@ -33,6 +33,7 @@ from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING
 from .configuration_bart import BartConfig
 from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
 from .configuration_bert_generation import BertGenerationConfig
+from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
 from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
 from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
 from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
@@ -158,6 +159,7 @@ from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast
 from .tokenization_bert_generation import BertGenerationTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
 from .tokenization_bertweet import BertweetTokenizer
+from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokenizer
 from .tokenization_camembert import CamembertTokenizer
 from .tokenization_ctrl import CTRLTokenizer
 from .tokenization_deberta import DebertaTokenizer
@@ -309,6 +311,7 @@ if is_torch_available():
         BertGenerationEncoder,
         load_tf_weights_in_bert_generation,
     )
+    from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
     from .modeling_camembert import (
         CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         CamembertForCausalLM,
@@ -21,6 +21,7 @@ from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
 from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
 from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
 from .configuration_bert_generation import BertGenerationConfig
+from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
 from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
 from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
 from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
@@ -58,6 +59,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
     for pretrained_map in [
         BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
         OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -101,6 +103,7 @@ CONFIG_MAPPING = OrderedDict(
         ("marian", MarianConfig),
         ("mbart", MBartConfig),
         ("bart", BartConfig),
+        ("blenderbot", BlenderbotConfig),
         ("reformer", ReformerConfig),
         ("longformer", LongformerConfig),
         ("roberta", RobertaConfig),
@@ -136,6 +139,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("camembert", "CamemBERT"),
         ("xlm-roberta", "XLM-RoBERTa"),
         ("pegasus", "Pegasus"),
+        ("blenderbot", "Blenderbot"),
         ("marian", "Marian"),
         ("mbart", "mBART"),
         ("bart", "BART"),
@@ -84,6 +84,8 @@ class BartConfig(PretrainedConfig):
             Don't learn positional embeddings, use sinusoidal.
         add_final_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Why not add another layernorm?
+        do_blenderbot_90_layernorm (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Blenderbot-90m checkpoint uses `layernorm_embedding` one line earlier in the decoder.
         scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Scale embeddings by dividing by sqrt(d_model).
         eos_token_id (:obj:`int`, `optional`, defaults to 2)
@@ -113,7 +115,7 @@ class BartConfig(PretrainedConfig):
     def __init__(
         self,
         activation_dropout=0.0,
-        extra_pos_embeddings=2,  # FIXME(@sshleifer): delete?
+        extra_pos_embeddings=2,
         activation_function="gelu",
         vocab_size=50265,
         d_model=1024,
@@ -137,6 +139,7 @@ class BartConfig(PretrainedConfig):
         eos_token_id=2,
         normalize_before=False,
         add_final_layer_norm=False,
+        do_blenderbot_90_layernorm=False,
         scale_embedding=False,
         normalize_embedding=True,
         static_position_embeddings=False,
@@ -198,10 +201,13 @@ class BartConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout

         # pos embedding offset
-        self.extra_pos_embeddings = self.pad_token_id + 1
+        self.extra_pos_embeddings = extra_pos_embeddings
+        # bart has a hack that offsets positional embeddings by 2, other models don't do this

         self.force_bos_token_to_be_generated = force_bos_token_to_be_generated

+        self.do_blenderbot_90_layernorm = do_blenderbot_90_layernorm
+
     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
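The two flags touched above cover small architectural quirks: `extra_pos_embeddings` shifts position indices by an offset (2 for BART, 0 for Blenderbot), and `do_blenderbot_90_layernorm` changes where `layernorm_embedding` is applied in the decoder. A standalone sketch of the position-offset idea (toy sizes, not the library's internal positional-embedding class):

import torch
import torch.nn as nn

class OffsetPositionalEmbedding(nn.Embedding):
    """Learned positions with an optional offset, mimicking BART's +2 hack."""

    def __init__(self, num_positions: int, d_model: int, offset: int = 0):
        # enlarge the table by `offset` so position i is looked up at index i + offset
        super().__init__(num_positions + offset, d_model)
        self.offset = offset

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device)
        return super().forward(positions + self.offset)

bart_style = OffsetPositionalEmbedding(1024, 16, offset=2)       # extra_pos_embeddings=2
blenderbot_style = OffsetPositionalEmbedding(512, 16, offset=0)  # extra_pos_embeddings=0
print(bart_style(torch.ones(1, 5, dtype=torch.long)).shape)      # torch.Size([5, 16])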
178  src/transformers/configuration_blenderbot.py  (new file)
@@ -0,0 +1,178 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Facebook, Inc. and Huggingface, 2020
#
# This source code is licensed under the MIT license found in the;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# LICENSE file in the root directory of this source tree.
"""BlenderbotConfig has the same signature as BartConfig. We only rewrite the signature in order to document blenderbot-90M defaults."""
from .configuration_bart import BartConfig


BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/blenderbot-3B": "https://cdn.huggingface.co/facebook/blenderbot-3B/config.json",
    "facebook/blenderbot-90M": "https://cdn.huggingface.co/facebook/blenderbot-90M/config.json",
}


class BlenderbotConfig(BartConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.BlenderbotForConditionalGeneration`.
    It inherits from :class:`~transformers.BartConfig` and has the same signature with different defaults.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
    to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
    for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 54944):
            Vocabulary size of the Blenderbot model. Defines the number of different tokens that can be represented by
            the :obj:`inputs_ids` passed when calling :class:`~transformers.BlenderbotForConditionalGeneration`.
        d_model (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (:obj:`int`, `optional`, defaults to 8):
            Number of encoder layers, 6 are used for the `blenderbot-90M` model.
        decoder_layers (:obj:`int`, `optional`, defaults to 8):
            Number of decoder layers, 6 are used for the `blenderbot-90M` model.
        encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler.
            If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
        dropout (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
        init_std (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        add_bias_logits (:obj:`bool`, `optional`, defaults to :obj:`False`):
            This should be completed, specific to marian.
        normalize_before (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Call layernorm before attention ops.
        normalize_embedding (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Call layernorm after embeddings.
        static_position_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Don't learn positional embeddings, use sinusoidal.
        add_final_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Why not add another layernorm?
        do_blenderbot_90_layernorm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Blenderbot-90m checkpoint uses `layernorm_embedding` one line earlier in the decoder.
        scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Scale embeddings by dividing by sqrt(d_model).
        eos_token_id (:obj:`int`, `optional`, defaults to 2)
            End of stream token id.
        pad_token_id (:obj:`int`, `optional`, defaults to 1)
            Padding token id.
        bos_token_id (:obj:`int`, `optional`, defaults to 0)
            Beginning of stream token id.
        encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the encoder. See the `LayerDrop paper
            <see https://arxiv.org/abs/1909.11556>`__ for more details.
        decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the decoder. See the `LayerDrop paper
            <see https://arxiv.org/abs/1909.11556>`__ for more details.
        extra_pos_embeddings: (:obj:`int`, `optional`, defaults to 2):
            How many extra learned positional embeddings to use. Should be set to :obj:`pad_token_id+1`.
        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether this is an encoder/decoder model.
        force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``).
    """
    model_type = "blenderbot"

    def __init__(
        self,
        activation_dropout=0.0,
        extra_pos_embeddings=0,
        activation_function="gelu",
        vocab_size=54944,
        d_model=512,
        encoder_ffn_dim=2048,
        encoder_layers=8,
        encoder_attention_heads=16,
        decoder_ffn_dim=2048,
        decoder_layers=8,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        attention_dropout=0.0,
        dropout=0.1,
        max_position_embeddings=512,
        classifier_dropout=0.0,
        is_encoder_decoder=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        normalize_before=False,
        add_final_layer_norm=False,
        do_blenderbot_90_layernorm=True,
        scale_embedding=False,
        normalize_embedding=True,
        static_position_embeddings=False,
        add_bias_logits=False,
        force_bos_token_to_be_generated=False,
        **common_kwargs
    ):
        r"""
        Examples::

            >>> from transformers import BlenderbotConfig
            >>> config = BlenderbotConfig.from_pretrained('facebook/blenderbot-90M')

        """
        if "hidden_size" in common_kwargs:
            raise ValueError("hidden size is called d_model")
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            vocab_size=vocab_size,
            d_model=d_model,
            encoder_ffn_dim=encoder_ffn_dim,
            encoder_layers=encoder_layers,
            encoder_layerdrop=encoder_layerdrop,
            encoder_attention_heads=encoder_attention_heads,
            decoder_layerdrop=decoder_layerdrop,
            decoder_ffn_dim=decoder_ffn_dim,
            decoder_layers=decoder_layers,
            normalize_before=normalize_before,
            normalize_embedding=normalize_embedding,
            static_position_embeddings=static_position_embeddings,
            add_bias_logits=add_bias_logits,
            force_bos_token_to_be_generated=force_bos_token_to_be_generated,
            do_blenderbot_90_layernorm=do_blenderbot_90_layernorm,
            add_final_layer_norm=add_final_layer_norm,
            scale_embedding=scale_embedding,
            attention_dropout=attention_dropout,
            dropout=dropout,
            classifier_dropout=classifier_dropout,
            activation_dropout=activation_dropout,
            max_position_embeddings=max_position_embeddings,
            extra_pos_embeddings=extra_pos_embeddings,
            activation_function=activation_function,
            decoder_attention_heads=decoder_attention_heads,
            **common_kwargs,
        )
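Since BlenderbotConfig only changes BartConfig's defaults, the quickest way to inspect what a given instance overrides is to_diff_dict(). A small sketch (the printed keys are indicative only; the 2/2-layer variant is a hypothetical toy config, not a released checkpoint):

from transformers import BlenderbotConfig

# the defaults documented above correspond to the 90M-sized architecture
default_cfg = BlenderbotConfig()
print(default_cfg.d_model, default_cfg.encoder_layers)  # 512 8

# only the overridden keys show up in to_diff_dict()
custom_cfg = BlenderbotConfig(encoder_layers=2, decoder_layers=2, d_model=16)
print(sorted(custom_cfg.to_diff_dict().keys()))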
@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Blenderbot checkpoint."""

import argparse

import torch

from transformers import BartConfig, BartForConditionalGeneration
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

PATTERNS = [
    ["attention", "attn"],
    ["encoder_attention", "encoder_attn"],
    ["q_lin", "q_proj"],
    ["k_lin", "k_proj"],
    ["v_lin", "v_proj"],
    ["out_lin", "out_proj"],
    ["norm_embeddings", "layernorm_embedding"],
    ["position_embeddings", "embed_positions"],
    ["embeddings", "embed_tokens"],
    ["ffn.lin", "fc"],
]


def rename_state_dict_key(k):
    if k == "embeddings.weight":
        return "shared.weight"

    for parlai_name, hf_name in PATTERNS:
        k = k.replace(parlai_name, hf_name)

    if k.startswith("encoder"):
        k = k.replace(".attn", ".self_attn")
        k = k.replace("norm1", "self_attn_layer_norm")
        k = k.replace("norm2", "final_layer_norm")
    elif k.startswith("decoder"):
        k = k.replace("norm1", "self_attn_layer_norm")
        k = k.replace("norm2", "encoder_attn_layer_norm")
        k = k.replace("norm3", "final_layer_norm")
    return k


def rename_layernorm_keys(sd):
    keys = [
        "model.encoder.layernorm_embedding.weight",
        "model.encoder.layernorm_embedding.bias",
        "model.decoder.layernorm_embedding.weight",
        "model.decoder.layernorm_embedding.bias",
    ]
    for k in keys:
        v = sd.pop(k)
        new_k = k.replace("layernorm_embedding", "layer_norm")
        assert new_k not in sd
        sd[new_k] = v


IGNORE_KEYS = ["START"]


@torch.no_grad()
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
    """
    Copy/paste/tweak model's weights to our BART structure.
    """
    model = torch.load(checkpoint_path, map_location="cpu")
    sd = model["model"]
    cfg = BartConfig.from_json_file(config_json_path)
    m = BartForConditionalGeneration(cfg)
    valid_keys = m.model.state_dict().keys()
    failures = []
    mapping = {}
    for k, v in sd.items():
        if k in IGNORE_KEYS:
            continue

        new_k = rename_state_dict_key(k)
        if new_k not in valid_keys:
            failures.append([k, new_k])
        else:
            mapping[new_k] = v
    if cfg.normalize_before:  # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
        rename_layernorm_keys(sd)
    m.model.load_state_dict(mapping, strict=True)
    m.half()
    m.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
    parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
    parser.add_argument(
        "--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
    )
    args = parser.parse_args()
    convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
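A hedged sketch of driving the converter above from Python rather than the CLI. The module name is assumed from the repository's usual naming for conversion scripts (the extracted diff does not show the file path), and the checkpoint/config paths are placeholders:

# hypothetical module name and local paths, mirroring what the argparse entry point does
from convert_blenderbot_original_pytorch_checkpoint_to_pytorch import convert_parlai_checkpoint

convert_parlai_checkpoint(
    checkpoint_path="blenderbot-model.bin",        # ParlAI checkpoint (placeholder)
    pytorch_dump_folder_path="hf_blenderbot",      # matches the --save_dir default
    config_json_path="blenderbot-3b-config.json",  # matches the --hf_config_json default
)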
@@ -24,6 +24,7 @@ from .configuration_auto import (
     BartConfig,
     BertConfig,
     BertGenerationConfig,
+    BlenderbotConfig,
     CamembertConfig,
     CTRLConfig,
     DebertaConfig,
@@ -82,6 +83,7 @@ from .modeling_bert import (
     BertModel,
 )
 from .modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
+from .modeling_blenderbot import BlenderbotForConditionalGeneration
 from .modeling_camembert import (
     CamembertForCausalLM,
     CamembertForMaskedLM,
@@ -353,6 +355,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
         (PegasusConfig, PegasusForConditionalGeneration),
         (MarianConfig, MarianMTModel),
         (MBartConfig, MBartForConditionalGeneration),
+        (BlenderbotConfig, BlenderbotForConditionalGeneration),
         (BartConfig, BartForConditionalGeneration),
         (FSMTConfig, FSMTForConditionalGeneration),
         (EncoderDecoderConfig, EncoderDecoderModel),
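With the mapping entry above in place, the auto classes can resolve a Blenderbot checkpoint without naming the model class. A short sketch (downloads the 90M weights on first use; the printed class names reflect the registrations added in this PR):

from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("facebook/blenderbot-90M")
print(type(config).__name__)  # BlenderbotConfig, via the CONFIG_MAPPING entry

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-90M")
print(type(model).__name__)   # BlenderbotForConditionalGeneration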
@@ -499,6 +499,7 @@ class BartDecoder(nn.Module):
         super().__init__()
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
+        self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm  # layernorm variant
         self.padding_idx = embed_tokens.padding_idx
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
@@ -578,8 +579,13 @@ class BartDecoder(nn.Module):
             positions = positions[:, -1:]

         x = self.embed_tokens(input_ids) * self.embed_scale
-        x += positions
-        x = self.layernorm_embedding(x)
+        if self.do_blenderbot_90_layernorm:
+            x = self.layernorm_embedding(x)
+            x += positions
+        else:
+            x += positions
+            x = self.layernorm_embedding(x)

         x = F.dropout(x, p=self.dropout, training=self.training)

         # Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
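The branch added above only swaps the order of the embedding steps: the 90M checkpoint applies `layernorm_embedding` before adding position embeddings, while every other BART-family checkpoint applies it after. A standalone sketch of the two orderings (toy sizes, not the library module):

import torch
import torch.nn as nn

d_model = 16
layernorm_embedding = nn.LayerNorm(d_model)
token_embeds = torch.randn(1, 5, d_model)  # stands in for embed_tokens(input_ids) * embed_scale
positions = torch.randn(1, 5, d_model)     # stands in for embed_positions(...)

# Blenderbot-90M ordering (do_blenderbot_90_layernorm=True): layernorm first, then add positions
x_90m = layernorm_embedding(token_embeds) + positions

# BART / Blenderbot-3B ordering (do_blenderbot_90_layernorm=False): add positions, then layernorm
x_bart = layernorm_embedding(token_embeds + positions)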
56  src/transformers/modeling_blenderbot.py  (new file)
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# LICENSE file in the root directory of this source tree.
"""BlenderbotForConditionalGeneration which inherits from BART"""

import torch

from .configuration_blenderbot import BlenderbotConfig
from .file_utils import add_start_docstrings
from .modeling_bart import BartForConditionalGeneration


BLENDER_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

"""

BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/blenderbot-90M"]


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
)
class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
    """
    This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the
    superclass for the appropriate documentation alongside usage examples.
    """

    config_class = BlenderbotConfig

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        logits[:, self.config.bos_token_id] = -torch.finfo(torch.float16).max  # near infinity fp16
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits
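The override above is Blenderbot's only behavioural difference from BART at generation time: the BOS logit is pushed to a very large negative value so the chatbot never emits BOS, and EOS is forced on the last step. A toy illustration of the masking idea (plain tensors, not the full generate loop):

import torch

bos_token_id = 0
logits = torch.zeros(1, 8)  # pretend vocabulary of 8 tokens

# never generate BOS: the most negative fp16-representable value acts as "minus infinity"
logits[:, bos_token_id] = -torch.finfo(torch.float16).max
print(logits.argmax(dim=-1) != bos_token_id)  # tensor([True])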
@@ -27,5 +27,5 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
         >>> translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
         >>> assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
     """
-
+    model_type = "mbart"
     config_class = MBartConfig
@@ -23,6 +23,7 @@ from .configuration_auto import (
     BartConfig,
     BertConfig,
     BertGenerationConfig,
+    BlenderbotConfig,
     CamembertConfig,
     CTRLConfig,
     DebertaConfig,
@@ -61,6 +62,7 @@ from .tokenization_bert import BertTokenizer, BertTokenizerFast
 from .tokenization_bert_generation import BertGenerationTokenizer
 from .tokenization_bert_japanese import BertJapaneseTokenizer
 from .tokenization_bertweet import BertweetTokenizer
+from .tokenization_blenderbot import BlenderbotSmallTokenizer
 from .tokenization_camembert import CamembertTokenizer
 from .tokenization_ctrl import CTRLTokenizer
 from .tokenization_deberta import DebertaTokenizer
@@ -108,6 +110,8 @@ TOKENIZER_MAPPING = OrderedDict(
         (MBartConfig, (MBartTokenizer, None)),
         (XLMRobertaConfig, (XLMRobertaTokenizer, None)),
         (MarianConfig, (MarianTokenizer, None)),
+        (BlenderbotConfig, (BlenderbotSmallTokenizer, None)),
+        (LongformerConfig, (LongformerTokenizer, None)),
         (BartConfig, (BartTokenizer, BartTokenizerFast)),
         (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
         (RobertaConfig, (BertweetTokenizer, None)),
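The TOKENIZER_MAPPING entry above is what makes AutoTokenizer hand back the small tokenizer for every Blenderbot config, which is why the docs earlier tell you to instantiate BlenderbotTokenizer explicitly for the 3B checkpoint. A quick sketch:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/blenderbot-90M")
print(type(tok).__name__)  # BlenderbotSmallTokenizer, via the mapping entry above

# The 3B checkpoint also resolves to the small tokenizer through the auto class,
# so load BlenderbotTokenizer explicitly when working with facebook/blenderbot-3B.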
271  src/transformers/tokenization_blenderbot.py  (new file)
@@ -0,0 +1,271 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# LICENSE file in the root directory of this source tree.
"""BlenderbotTokenizer and BlenderbotSmallTokenizer"""
import json
import os
from typing import Dict, List, Tuple

import regex as re

from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .utils import logging


logger = logging.get_logger(__name__)


VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "merges_file": "merges.txt",
    # "tokenizer_config_file": "tokenizer_config.json",
}
CKPT_3B = "facebook/blenderbot-3B"


class BlenderbotTokenizer(RobertaTokenizer):
    r"""
    Construct a Blenderbot tokenizer.

    :class:`~transformers.BlenderbotTokenizer` is nearly identical to :class:`~transformers.RobertaTokenizer` and runs
    end-to-end tokenization: punctuation splitting and wordpiece. The only difference is that it doesn't add the BOS
    token to the beginning of sequences.

    Refer to superclass :class:`~transformers.RobertaTokenizer` for usage examples and documentation concerning
    parameters.
    """
    vocab_files_names = {
        "vocab_file": "vocab.json",
        "merges_file": "merges.txt",
        "tokenizer_config_file": "tokenizer_config.json",
    }
    pretrained_vocab_files_map = {
        "vocab_file": {CKPT_3B: "https://cdn.huggingface.co/facebook/blenderbot-3B/vocab.json"},
        "merges_file": {CKPT_3B: "https://cdn.huggingface.co/facebook/blenderbot-3B/merges.txt"},
        "tokenizer_config_file": {CKPT_3B: "https://cdn.huggingface.co/facebook/blenderbot-3B/tokenizer_config.json"},
    }
    max_model_input_sizes = {"facebook/blenderbot-3B": 128}

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: List[int] = None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A Blenderbot sequence has the following format:

        - single sequence: `` X </s>``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Will be ignored.

        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        return token_ids_0 + [self.eos_token_id]


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char

    pairs = set(pairs)
    return pairs


class BlenderbotSmallTokenizer(PreTrainedTokenizer):
    """
    Constructs a Blenderbot-90M tokenizer based on BPE (Byte-Pair-Encoding).

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to the superclass for more information regarding methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        merges_file (:obj:`str`):
            Path to the merges file.
        bos_token (:obj:`str`, `optional`, defaults to :obj:`"__start__"`):
            The beginning of sentence token.
        eos_token (:obj:`str`, `optional`, defaults to :obj:`"__end__"`):
            The end of sentence token.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"__unk__"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"__null__"`):
            The token used for padding, for example when batching sequences of different lengths.
        **kwargs
            Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
    """

    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    pretrained_vocab_files_map = {
        "vocab_file": {"facebook/blenderbot-90M": "https://cdn.huggingface.co/facebook/blenderbot-90M/vocab.json"},
        "merges_file": {"facebook/blenderbot-90M": "https://cdn.huggingface.co/facebook/blenderbot-90M/merges.txt"},
    }
    max_model_input_sizes = {"facebook/blenderbot-90M": 512}

    def __init__(
        self,
        vocab_file,
        merges_file,
        bos_token="__start__",
        eos_token="__end__",
        unk_token="__unk__",
        pad_token="__null__",
        **kwargs
    ):
        super().__init__(unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, pad_token=pad_token, **kwargs)

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            merges = merges_handle.read().split("\n")[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def get_vocab(self) -> Dict:
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token: str) -> str:
        if token in self.cache:
            return self.cache[token]
        token = re.sub("([.,!?()])", r" \1", token)
        token = re.sub("(')", r" \1 ", token)
        token = re.sub("\s{2,}", " ", token)
        if "\n" in token:
            token = token.replace("\n", " __newln__")

        tokens = token.split(" ")
        words = []
        for token in tokens:
            token = token.lower()
            word = tuple(token)
            word = tuple(list(word[:-1]) + [word[-1] + "</w>"])
            pairs = get_pairs(word)

            if not pairs:
                words.append(token)
                continue

            while True:
                bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
                if bigram not in self.bpe_ranks:
                    break
                first, second = bigram
                new_word = []
                i = 0

                while i < len(word):
                    try:
                        j = word.index(first, i)
                        new_word.extend(word[i:j])
                        i = j
                    except ValueError:
                        new_word.extend(word[i:])
                        break

                    if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                        new_word.append(first + second)
                        i += 2
                    else:
                        new_word.append(word[i])
                        i += 1
                new_word = tuple(new_word)
                word = new_word
                if len(word) == 1:
                    break
                else:
                    pairs = get_pairs(word)
            word = "@@ ".join(word)
            word = word[:-4]

            self.cache[token] = word
            words.append(word)
        return " ".join(words)

    def _tokenize(self, text: str) -> List[str]:
        """Split a string into tokens using BPE."""
        split_tokens = []

        words = re.findall(r"\S+\n?", text)

        for token in words:
            split_tokens.extend([t for t in self.bpe(token).split(" ")])
        return split_tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token to an id using the vocab."""
        token = token.lower()
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens in a single string."""
        out_string = " ".join(tokens).replace("@@ ", "").strip()
        return out_string

    def save_vocabulary(self, save_directory: str) -> Tuple[str, str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (:obj:`str`):
                The directory in which to save the vocabulary.

        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES["merges_file"])

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!".format(merge_file)
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return vocab_file, merge_file
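A short sketch of the small tokenizer's behaviour, using the public 90M files. The bpe() method lowercases, splits off punctuation and marks word-internal pieces with a trailing "@@"; convert_tokens_to_string undoes those markers. The decoded string in the comment is indicative of that convention rather than an exact documented output:

from transformers import BlenderbotSmallTokenizer

tokenizer = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot-90M")

tokens = tokenizer.tokenize("My friends are cool.")
ids = tokenizer.convert_tokens_to_ids(tokens)

print(tokenizer.convert_tokens_to_string(tokens))  # e.g. "my friends are cool ."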
@@ -449,6 +449,18 @@ def load_tf_weights_in_bert_generation(*args, **kwargs):
     requires_pytorch(load_tf_weights_in_bert_generation)


+BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class BlenderbotForConditionalGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
50  test_tokenization_blenderbot.py  (new file)
@@ -0,0 +1,50 @@
import json
import os
import unittest

from transformers.testing_utils import slow
from transformers.tokenization_blenderbot import VOCAB_FILES_NAMES, BlenderbotTokenizer, BlenderbotSmallTokenizer

from .test_tokenization_common import TokenizerTesterMixin


class BlenderbotSmallTokenizerTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BlenderbotSmallTokenizer

    def setUp(self):
        super().setUp()

        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["adapt", "react", "read@@", "ap@@", "t", "__unk__", "__start__", "__end__", "__null__"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "a p", "ap t</w>", "r e", "a d", "ad apt</w>", ""]
        self.special_tokens_map = {"bos_token": "__start__", "eos_token": "__end__", "pad_token": "__null__", "unk_token": "__unk__"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    def get_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return BlenderbotSmallTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "adapt react readapt apt"
        output_text = "adapt react readapt apt"
        return input_text, output_text

    def test_full_blenderbot_small_tokenizer(self):
        tokenizer = BlenderbotSmallTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "adapt react readapt apt"
        bpe_tokens = ['adapt', 'react', 'read@@', 'ap@@', 't', 'ap@@', 't']
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = [tokenizer.bos_token] + tokens + [tokenizer.eos_token]
        print(input_tokens)

        # input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
        # self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
@@ -212,8 +212,9 @@ class AutoModelTest(unittest.TestCase):
         mapping = tuple(mapping.items())
         for index, (child_config, child_model) in enumerate(mapping[1:]):
             for parent_config, parent_model in mapping[: index + 1]:
                 with self.subTest(
                     msg="Testing if {} is child of {}".format(child_config.__name__, parent_config.__name__)
                 ):
-                    self.assertFalse(issubclass(child_config, parent_config))
-                    self.assertFalse(issubclass(child_model, parent_model))
+                    assert not issubclass(
+                        child_config, parent_config
+                    ), "{child_config.__name__} is child of {parent_config.__name__}"
+                    assert not issubclass(
+                        child_model, parent_model
+                    ), "{child_config.__name__} is child of {parent_config.__name__}"
@@ -40,6 +40,11 @@ if is_torch_available():
         BartModel,
         BartTokenizer,
         BartTokenizerFast,
+        BertConfig,
+        BlenderbotConfig,
+        MarianConfig,
+        MBartConfig,
+        PegasusConfig,
         pipeline,
     )
     from transformers.modeling_bart import (
@@ -175,7 +180,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         decoder_features_with_passed_mask = model(
             decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
         )[0]
-        _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
+        assert_tensors_close(decoder_features_with_passed_mask, decoder_features_with_created_mask)
         useless_mask = torch.zeros_like(decoder_attn_mask)
         decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
         self.assertTrue(isinstance(decoder_features, torch.Tensor))  # no hidden states or attentions
@@ -189,7 +194,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         decoder_features_with_long_encoder_mask = model(
             inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
         )[0]
-        _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
+        assert_tensors_close(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)

     def test_save_load_strict(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs()
@@ -364,7 +369,7 @@ class BartHeadTests(unittest.TestCase):
         ]
         for ex, desired_result in zip(examples, fairseq_results):
             bart_toks = tokenizer.encode(ex, return_tensors="pt")
-            _assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
+            assert_tensors_close(desired_result.long(), bart_toks, prefix=ex)

     def test_generate_fp16(self):
         config, input_ids, batch_size = self._get_config_and_data()
@@ -411,16 +416,22 @@ class BartHeadTests(unittest.TestCase):
         self.assertTrue(torch.eq(input_new, output_new).all())


-def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
-    """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
+def assert_tensors_close(a, b, atol=1e-12, prefix=""):
+    """If tensors not close, or a and b aren't both tensors, raise a nice Assertion error."""

    if a is None and b is None:
        return True
    assert a.shape == b.shape
    try:
        if torch.allclose(a, b, atol=atol):
            return True
        raise
    except Exception:
-        msg = "{} != {}".format(a, b)
+        pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
+        if a.numel() > 100:
+            msg = f"tensor values are {pct_different:.1%} percent different."
+        else:
+            msg = f"{a} != {b}"
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg)
@@ -496,8 +507,8 @@ class BartModelIntegrationTests(unittest.TestCase):
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
         with torch.no_grad():
             logits2 = model(**inputs_dict)[0]
-        _assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE)
-        _assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)
+        assert_tensors_close(batched_logits[1], logits2, atol=TOLERANCE)
+        assert_tensors_close(expected_slice, logits_arr, atol=TOLERANCE)

     @slow
     def test_xsum_summarization_same_as_fairseq(self):
@@ -633,3 +644,12 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
                 torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
             )
         )
+
+    def test_child_config_equivalence(self):
+        """Test that configs associated with children of BartForConditionalGeneration are identical."""
+        child_classes = [BlenderbotConfig, MBartConfig, MarianConfig, PegasusConfig]
+        parent_keys = BartConfig().to_dict().keys()
+        for c in child_classes:
+            assert c().to_dict().keys() == parent_keys  # traceback is very nice on its own
+
+        # check that test is not stupid
+        assert BertConfig().to_dict().keys() != parent_keys
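The new test_child_config_equivalence check relies on subclass configs only changing defaults, never adding or removing keys. A standalone restatement of the same invariant, true as of this PR since BlenderbotConfig simply overrides BartConfig defaults:

from transformers import BartConfig, BlenderbotConfig

parent_keys = set(BartConfig().to_dict().keys())
child_keys = set(BlenderbotConfig().to_dict().keys())

assert child_keys == parent_keys                           # same schema
assert BlenderbotConfig().d_model != BartConfig().d_model  # different defaults: 512 vs 1024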
215
tests/test_modeling_blenderbot.py
Normal file
215
tests/test_modeling_blenderbot.py
Normal file
@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the;
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Tests for BlenderBot"""
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BlenderbotConfig,
|
||||
BlenderbotForConditionalGeneration,
|
||||
BlenderbotSmallTokenizer,
|
||||
BlenderbotTokenizer,
|
||||
)
|
||||
|
||||
TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BlenderbotModelTester:
|
||||
# Required attributes
|
||||
vocab_size = 99
|
||||
batch_size = 13
|
||||
seq_length = 7
|
||||
num_hidden_layers = 2
|
||||
hidden_size = 16
|
||||
num_attention_heads = 4
|
||||
is_training = True
|
||||
|
||||
def __init__(self, parent):
|
||||
torch.manual_seed(0)
|
||||
self.parent = parent
|
||||
self.config = BlenderbotConfig(
|
||||
d_model=self.hidden_size,
|
||||
dropout=0.0,
|
||||
activation_function="gelu",
|
||||
vocab_size=self.vocab_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
attention_dropout=0.0,
|
||||
encoder_ffn_dim=4,
|
||||
decoder_ffn_dim=4,
|
||||
do_blenderbot_90_layernorm=False,
|
||||
normalize_before=True,
|
||||
max_position_embeddings=50,
|
||||
static_position_embeddings=False,
|
||||
scale_embedding=True,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
num_beams=1,
|
||||
min_length=3,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return self.config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available():
|
||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,)
|
||||
all_model_classes = (BlenderbotForConditionalGeneration,)
|
||||
else:
|
||||
all_generative_model_classes = ()
|
||||
all_model_classes = ()
|
||||
is_encoder_decoder = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlenderbotModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_initialization_module(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = BlenderbotForConditionalGeneration(config).model
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
enc_embeds = model.encoder.embed_tokens.weight
|
||||
assert (enc_embeds == model.shared.weight).all().item()
|
||||
self.assertAlmostEqual(torch.std(enc_embeds).item(), config.init_std, 2)
|
||||
|
||||
def test_embed_pos_shape(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = BlenderbotForConditionalGeneration(config)
|
||||
expected_shape = (config.max_position_embeddings + config.extra_pos_embeddings, config.d_model)
|
||||
assert model.model.encoder.embed_positions.weight.shape == expected_shape
|
||||
model.model.decoder.embed_positions.weight.shape == expected_shape
|
||||
|
||||
@unittest.skip("This test is flaky")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
|
||||
@require_torch
|
||||
class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
ckpt = "facebook/blenderbot-3B"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).to(torch_device)
|
||||
if torch_device == "cuda":
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return BlenderbotTokenizer.from_pretrained(self.ckpt)
|
||||
|
||||
@slow
|
||||
def test_generation_from_short_input_same_as_parlai_3B(self):
|
||||
|
||||
src_text = ["Sam"]
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
generated_utterances = self.model.generate(**model_inputs, **FASTER_GEN_KWARGS)
|
||||
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
|
||||
|
||||
generated_txt = self.tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
|
||||
assert generated_txt[0].strip() == tgt_text
|
||||
|
||||
@slow
|
||||
def test_generation_from_long_input_same_as_parlai_3B(self):
|
||||
|
||||
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
||||
|
||||
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
|
||||
generated_ids = self.model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
|
||||
assert "I think it's because we are so worried about what people think of us." == reply.strip()
|
||||
|
||||
|
||||
@require_torch
class Blenderbot90MIntegrationTests(unittest.TestCase):
    ckpt = "facebook/blenderbot-90M"

    @cached_property
    def model(self):
        model = AutoModelForSeq2SeqLM.from_pretrained(self.ckpt).to(torch_device)
        if torch_device == "cuda":
            model = model.half()
        return model

    @cached_property
    def tokenizer(self):
        return AutoTokenizer.from_pretrained(self.ckpt)

    @slow
    def test_90_generation_from_long_input(self):
        src_text = [
            "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like\
            i'm going to throw up.\nand why is that?"
        ]

        model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
        assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
        assert self.model.config.do_blenderbot_90_layernorm  # 90M checkpoint uses the ParlAI-style layernorm placement
        generated_ids = self.model.generate(**model_inputs)[0]
        reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)

        assert reply in (
            "i don't know. i just feel like i'm going to throw up. it's not fun.",
            "i'm not sure. i just feel like i've been feeling like i have to be in a certain place",
        )
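        # Two reference replies are accepted above; presumably fp16 generation on CUDA
        # and fp32 generation on CPU can produce slightly different (but still sensible) outputs.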

    def test_90_generation_from_short_input(self):
        model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
        generated_utterances = self.model.generate(**model_inputs)
        # generated_txt = self.tokenizer.decode(generated_utterances[0])

        # assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
        clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
        assert clean_txt in (
            "have you ever been to a sam club? it's a great club in the south.",
            "have you ever heard of sam harris? he's an american singer, songwriter, and actor.",
        )
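For context, a minimal usage sketch of the behaviour these integration tests pin down; the checkpoint name comes from the tests above, while generation settings are left at their defaults, so the exact reply may differ:

from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer

tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B")

inputs = tokenizer(["Sam"], return_tensors="pt")        # encode a short utterance
reply_ids = model.generate(**inputs)                    # generation defaults come from the model config
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))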
@ -752,6 +752,10 @@ class ModelTesterMixin:

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t
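        # NaN is the only float value that compares unequal to itself, so the helper above
        # zeroes out NaN entries before tuple vs. dict outputs are compared with torch.allclose.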

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
@ -765,7 +769,9 @@ class ModelTesterMixin:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(tuple_object, dict_object, atol=1e-5),
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
                        )


@ -4,7 +4,7 @@ from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_modeling_bart import TOLERANCE, _assert_tensors_equal, _long_tensor
from .test_modeling_bart import TOLERANCE, _long_tensor, assert_tensors_close


if is_torch_available():
@ -79,7 +79,17 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):

        expected_slice = torch.tensor([9.0078, 10.1113, 14.4787], device=logits.device, dtype=logits.dtype)
        result_slice = logits[0, 0, :3]
        _assert_tensors_equal(expected_slice, result_slice, atol=TOLERANCE)
        assert_tensors_close(expected_slice, result_slice, atol=TOLERANCE)

    @slow
    def test_enro_generate_one(self):
        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
            ["UN Chief Says There Is No Military Solution in Syria"]
        ).to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])
        # self.assertEqual(self.tgt_text[1], decoded[1])

    @slow
    def test_enro_generate(self):

93
tests/test_tokenization_blenderbot.py
Normal file
@ -0,0 +1,93 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Blenderbot Tokenizers, including common tests for BlenderbotSmallTokenizer."""
import json
import os
import unittest

from transformers.file_utils import cached_property
from transformers.tokenization_blenderbot import VOCAB_FILES_NAMES, BlenderbotSmallTokenizer, BlenderbotTokenizer

from .test_tokenization_common import TokenizerTesterMixin


class BlenderbotSmallTokenizerTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BlenderbotSmallTokenizer

    def setUp(self):
        super().setUp()

        vocab = ["__start__", "adapt", "act", "ap@@", "te", "__end__", "__unk__"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))

        merges = ["#version: 0.2", "a p", "t e</w>", "ap t</w>", "a d", "ad apt</w>", "a c", "ac t</w>", ""]
        self.special_tokens_map = {"unk_token": "__unk__", "bos_token": "__start__", "eos_token": "__end__"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))
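        # The toy vocab/merges above define a tiny BPE tokenizer: "ap@@" marks a
        # non-final subword piece, so "apte" is expected to split into ["ap@@", "te"]
        # (exercised in test_full_blenderbot_small_tokenizer below).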

    def get_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return BlenderbotSmallTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "adapt act apte"
        output_text = "adapt act apte"
        return input_text, output_text

    def test_full_blenderbot_small_tokenizer(self):
        tokenizer = BlenderbotSmallTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "adapt act apte"
        bpe_tokens = ["adapt", "act", "ap@@", "te"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = [tokenizer.bos_token] + tokens + [tokenizer.eos_token]

        input_bpe_tokens = [0, 1, 2, 3, 4, 5]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_special_tokens_small_tok(self):
        tok = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot-90M")
        assert tok("sam").input_ids == [1384]
        src_text = "I am a small frog."
        encoded = tok([src_text], padding=False, truncation=False)["input_ids"]
        decoded = tok.batch_decode(encoded, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        assert src_text != decoded  # I wish it did!
        assert decoded == "i am a small frog ."


class Blenderbot3BTokenizerTests(unittest.TestCase):
    @cached_property
    def tokenizer_3b(self):
        return BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")

    def test_encode_decode_cycle(self):
        tok = self.tokenizer_3b
        src_text = " I am a small frog."
        encoded = tok([src_text], padding=False, truncation=False)["input_ids"]
        decoded = tok.batch_decode(encoded, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        assert src_text == decoded

    def test_3B_tokenization_same_as_parlai(self):
        assert self.tokenizer_3b.add_prefix_space
        assert self.tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]