Add Flaubert

This commit is contained in:
Hang Le 2020-01-15 17:22:25 +01:00 committed by Lysandre Debut
parent a5381495e6
commit f0a4fc6cd6
7 changed files with 761 additions and 3 deletions

View File: README.md

@ -160,8 +160,9 @@ At some point in the future, you'll be able to seamlessly move from pre-training
12. **[T5](https://github.com/google-research/text-to-text-transfer-transformer)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
13. **[XLM-RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/xlmr)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
14. **[MMBT](https://github.com/facebookresearch/mmbt/)** (from Facebook), released together with the paper [Supervised Multimodal Bitransformers for Classifying Images and Text](https://arxiv.org/pdf/1909.02950.pdf) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
15. **[FlauBERT](https://github.com/getalp/Flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab.
16. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users).
17. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
These implementations have been tested on several datasets (see the example scripts) and should match the performance of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on performance in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).

View File: examples/run_glue.py

@ -41,6 +41,9 @@ from transformers import (
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer,
FlaubertConfig,
FlaubertForSequenceClassification,
FlaubertTokenizer,
RobertaConfig,
RobertaForSequenceClassification,
RobertaTokenizer,
@ -80,6 +83,7 @@ ALL_MODELS = sum(
DistilBertConfig,
AlbertConfig,
XLMRobertaConfig,
FlaubertConfig,
)
),
(),
@ -93,6 +97,7 @@ MODEL_CLASSES = {
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
"xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
"flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
}
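For orientation, a minimal sketch (not part of the diff) of how the script consumes the new entry once `--model_type flaubert` is passed; the checkpoint name is one of those registered later in this commit:

config_class, model_class, tokenizer_class = MODEL_CLASSES["flaubert"]
config = config_class.from_pretrained("flaubert-base-cased", num_labels=2)
tokenizer = tokenizer_class.from_pretrained("flaubert-base-cased")
model = model_class.from_pretrained("flaubert-base-cased", config=config)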
@ -480,7 +485,7 @@ def main():
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument(
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step.",
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
)
parser.add_argument(
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",

View File: transformers/__init__.py

@ -25,6 +25,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_mmbt import MMBTConfig
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
@ -108,6 +109,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenize
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer
@ -260,6 +262,15 @@ if is_torch_available():
)
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
from .modeling_flaubert import (
FlaubertModel,
FlaubertWithLMHeadModel,
FlaubertForSequenceClassification,
FlaubertForQuestionAnswering,
FlaubertForQuestionAnsweringSimple,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
# Optimization
from .optimization import (
AdamW,

View File: transformers/configuration_auto.py

@ -23,6 +23,7 @@ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
@ -53,6 +54,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value in pretrained_map.items()
)
@ -73,6 +75,7 @@ CONFIG_MAPPING = OrderedDict(
("xlnet", XLNetConfig,),
("xlm", XLMConfig,),
("ctrl", CTRLConfig,),
("flaubert", FlaubertConfig,),
]
)
@ -126,6 +129,7 @@ class AutoConfig(object):
- contains `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMConfig` (XLM model)
- contains `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
- contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
Args:

View File: transformers/configuration_flaubert.py

@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flaubert configuration, based on XLM. """
import logging
from .configuration_xlm import XLMConfig
logger = logging.getLogger(__name__)
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/config.json",
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/config.json",
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/config.json",
}
class FlaubertConfig(XLMConfig):
"""Configuration class to store the configuration of a `FlaubertModel`.
Args:
vocab_size: Vocabulary size of `input_ids` in `FlaubertModel`.
d_model: Size of the encoder layers and the pooler layer.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
d_inner: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
ff_activation: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
untie_r: untie relative position biases
attn_type: 'bi' for Flaubert, 'uni' for Transformer-XL
dropout: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
dropout: float, dropout rate.
init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the current batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "flaubert"
def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs):
"""Constructs FlaubertConfig.
"""
super().__init__(**kwargs)
self.layerdrop = layerdrop
self.pre_norm = pre_norm
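A quick illustration (a sketch, not part of the commit) of the two fields this subclass adds on top of `XLMConfig`:

config = FlaubertConfig()  # inherits all XLM hyper-parameters
assert config.layerdrop == 0.0  # probability of dropping a whole layer during training
assert config.pre_norm is False  # layer-norm applied after each sub-layer (post-norm) by default
custom = FlaubertConfig(layerdrop=0.2, pre_norm=True)  # pre-norm variant with stochastic layer dropping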

View File: transformers/modeling_flaubert.py

@ -0,0 +1,510 @@
# coding=utf-8
# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Flaubert model, based on XLM. """
import logging
import torch
from torch.nn import functional as F
from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
from .modeling_xlm import (
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
XLMModel,
XLMWithLMHeadModel,
get_masks,
)
logger = logging.getLogger(__name__)
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/pytorch_model.bin",
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/pytorch_model.bin",
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/pytorch_model.bin",
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/pytorch_model.bin",
}
FLAUBERT_START_DOCSTRING = r""" The Flaubert model was proposed in
`FlauBERT: Unsupervised Language Model Pre-training for French`_
by Hang Le et al. It's a transformer pre-trained using a masked
language modeling (MLM) objective (BERT-like).
Original code can be found `here`_.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matters related to general usage and behavior.
.. _`FlauBERT: Unsupervised Language Model Pre-training for French`:
https://arxiv.org/abs/1912.05372
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
.. _`here`:
https://github.com/getalp/Flaubert
Parameters:
config (:class:`~transformers.FlaubertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
FLAUBERT_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Flaubert is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
Indices can be obtained using :class:`transformers.FlaubertTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**lengths**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Length of each sentence that can be used to avoid performing attention on padding token indices.
You can also use `attention_mask` for the same result (see above); kept here for compatibility.
Indices selected in ``[0, ..., input_ids.size(-1)]``.
**cache**:
dictionary with ``torch.FloatTensor`` that contains pre-computed
hidden-states (key and values in the attention blocks) as computed by the model
(see `cache` output below). Can be used to speed up sequential decoding.
The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
"""
@add_start_docstrings(
"The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
FLAUBERT_START_DOCSTRING,
FLAUBERT_INPUTS_DOCSTRING,
)
class FlaubertModel(XLMModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-cased')
model = FlaubertModel.from_pretrained('flaubert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Le chat mange une pomme.", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
config_class = FlaubertConfig
pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
model_type = "flaubert"
def __init__(self, config): # , dico, is_encoder, with_output):
super(FlaubertModel, self).__init__(config)
self.layerdrop = getattr(config, "layerdrop", 0.0)
self.pre_norm = getattr(config, "pre_norm", False)
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
): # removed: src_enc=None, src_len=None
if input_ids is not None:
bs, slen = input_ids.size()
else:
bs, slen = inputs_embeds.size()[:-1]
if lengths is None:
if input_ids is not None:
lengths = (input_ids != self.pad_index).sum(dim=1).long()
else:
lengths = torch.LongTensor([slen] * bs)
# mask = input_ids != self.pad_index
# check inputs
assert lengths.size(0) == bs
assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
# if src_enc is not None:
# assert self.is_decoder
# assert src_enc.size(0) == bs
# generate masks
mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
device = input_ids.device if input_ids is not None else inputs_embeds.device
# position_ids
if position_ids is None:
position_ids = torch.arange(slen, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
else:
assert position_ids.size() == (bs, slen) # (slen, bs)
# position_ids = position_ids.transpose(0, 1)
# langs
if langs is not None:
assert langs.size() == (bs, slen) # (slen, bs)
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.n_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.n_layers
# do not recompute cached elements
if cache is not None and input_ids is not None:
_slen = slen - cache["slen"]
input_ids = input_ids[:, -_slen:]
position_ids = position_ids[:, -_slen:]
if langs is not None:
langs = langs[:, -_slen:]
mask = mask[:, -_slen:]
attn_mask = attn_mask[:, -_slen:]
# embeddings
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
tensor = self.layer_norm_emb(tensor)
tensor = F.dropout(tensor, p=self.dropout, training=self.training)
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers
hidden_states = ()
attentions = ()
for i in range(self.n_layers):
if self.output_hidden_states:
hidden_states = hidden_states + (tensor,)
# self attention
if not self.pre_norm:
attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
attn = attn_outputs[0]
if self.output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
else:
tensor_normalized = self.layer_norm1[i](tensor)
attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i])
attn = attn_outputs[0]
if self.output_attentions:
attentions = attentions + (attn_outputs[1],)
attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn
# encoder attention (for decoder only)
# if self.is_decoder and src_enc is not None:
# attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
# attn = F.dropout(attn, p=self.dropout, training=self.training)
# tensor = tensor + attn
# tensor = self.layer_norm15[i](tensor)
# FFN
if not self.pre_norm:
tensor = tensor + self.ffns[i](tensor)
tensor = self.layer_norm2[i](tensor)
else:
tensor_normalized = self.layer_norm2[i](tensor)
tensor = tensor + self.ffns[i](tensor_normalized)
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# Add last hidden state
if self.output_hidden_states:
hidden_states = hidden_states + (tensor,)
# update cache length
if cache is not None:
cache["slen"] += tensor.size(1)
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if self.output_hidden_states:
outputs = outputs + (hidden_states,)
if self.output_attentions:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
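# Note (added for clarity, not part of the original commit): the `pre_norm`
# branches above implement the two standard residual orderings:
#     post-norm (pre_norm=False): x = LayerNorm(x + Sublayer(x))
#     pre-norm  (pre_norm=True):  x = x + Sublayer(LayerNorm(x))
# Pre-norm tends to make deeper stacks easier to train; each pretrained
# checkpoint selects the variant it was trained with through its config.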
@add_start_docstrings(
"""The Flaubert Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
FLAUBERT_START_DOCSTRING,
FLAUBERT_INPUTS_DOCSTRING,
)
class FlaubertWithLMHeadModel(XLMWithLMHeadModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-cased')
model = FlaubertWithLMHeadModel.from_pretrained('flaubert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Le chat mange une pomme.", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
prediction_scores = outputs[0] # Prediction scores of the language modeling head are the first element of the output tuple
"""
config_class = FlaubertConfig
pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config):
super(FlaubertWithLMHeadModel, self).__init__(config)
self.transformer = FlaubertModel(config)
self.init_weights()
@add_start_docstrings(
"""Flaubert Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
FLAUBERT_START_DOCSTRING,
FLAUBERT_INPUTS_DOCSTRING,
)
class FlaubertForSequenceClassification(XLMForSequenceClassification):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-cased')
model = FlaubertForSequenceClassification.from_pretrained('flaubert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Le chat mange une pomme.", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
config_class = FlaubertConfig
pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config):
super(FlaubertForSequenceClassification, self).__init__(config)
self.transformer = FlaubertModel(config)
self.init_weights()
@add_start_docstrings(
"""Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
FLAUBERT_START_DOCSTRING,
FLAUBERT_INPUTS_DOCSTRING,
)
class FlaubertForQuestionAnsweringSimple(XLMForQuestionAnsweringSimple):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
**is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels whether a question has an answer or no answer (SQuAD 2.0)
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-cased')
model = FlaubertForQuestionAnsweringSimple.from_pretrained('flaubert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Le chat mange une pomme", add_special_tokens=True)).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:3]
"""
config_class = FlaubertConfig
pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config):
super(FlaubertForQuestionAnsweringSimple, self).__init__(config)
self.transformer = FlaubertModel(config)
self.init_weights()
@add_start_docstrings(
"""Flaubert Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
FLAUBERT_START_DOCSTRING,
FLAUBERT_INPUTS_DOCSTRING,
)
class FlaubertForQuestionAnswering(XLMForQuestionAnswering):
r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
**is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels whether a question has an answer or no answer (SQuAD 2.0)
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-cased')
model = FlaubertForQuestionAnswering.from_pretrained('flaubert-base-cased')
input_ids = torch.tensor(tokenizer.encode("Le chat mange une pomme.", add_special_tokens=True)).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss = outputs[0] # when start/end positions are given, the first output is the total span loss
"""
config_class = FlaubertConfig
pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
def __init__(self, config):
super(FlaubertForQuestionAnswering, self).__init__(config)
self.transformer = FlaubertModel(config)
self.init_weights()
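One thing worth flagging: the constructor reads `config.layerdrop`, but the forward pass above never consults it, so layer dropping is effectively inert in this version. For reference, LayerDrop (Fan et al., 2019) is normally applied at the top of the layer loop; a hedged sketch of what that would look like inside `for i in range(self.n_layers):`:

import random

# Sketch only: stochastically skip whole layers during training.
if self.training and random.uniform(0, 1) < self.layerdrop:
    continue  # skip this layer for the current forward pass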

View File: transformers/tokenization_flaubert.py

@ -0,0 +1,145 @@
# coding=utf-8
# Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Flaubert, based on XLM."""
import logging
import unicodedata
import six
from .tokenization_xlm import XLMTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/vocab.json",
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/vocab.json",
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/vocab.json",
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/vocab.json",
},
"merges_file": {
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/merges.txt",
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/merges.txt",
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/merges.txt",
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"flaubert-small-cased": 512,
"flaubert-base-uncased": 512,
"flaubert-base-cased": 512,
"flaubert-large-cased": 512,
}
PRETRAINED_INIT_CONFIGURATION = {
"flaubert-small-cased": {"do_lowercase": False},
"flaubert-base-uncased": {"do_lowercase": True},
"flaubert-base-cased": {"do_lowercase": False},
"flaubert-large-cased": {"do_lowercase": False},
}
def convert_to_unicode(text):
"""
Converts `text` to Unicode (if it's not already), assuming UTF-8 input.
"""
# six_ensure_text is copied from https://github.com/benjaminp/six
def six_ensure_text(s, encoding="utf-8", errors="strict"):
if isinstance(s, six.binary_type):
return s.decode(encoding, errors)
elif isinstance(s, six.text_type):
return s
else:
raise TypeError("not expecting type '%s'" % type(s))
return six_ensure_text(text, encoding="utf-8", errors="ignore")
class FlaubertTokenizer(XLMTokenizer):
"""
BPE tokenizer for Flaubert
- Moses preprocessing and tokenization
- Normalizes all input text
- The argument ``special_tokens`` and the function ``set_special_tokens`` can be used to add additional symbols \
(ex: "__classify__") to a vocabulary
- The `do_lowercase` argument controls lowercasing (automatically set for pretrained vocabularies)
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, do_lowercase=False, **kwargs):
super().__init__(**kwargs)
self.do_lowercase = do_lowercase
self.do_lowercase_and_remove_accent = False
def preprocess_text(self, text):
text = text.replace("``", '"').replace("''", '"')
text = convert_to_unicode(text)
text = unicodedata.normalize("NFC", text)
if self.do_lowercase:
text = text.lower()
return text
def _tokenize(self, text, bypass_tokenizer=False):
"""
Tokenize a string given language code using Moses.
Details of tokenization:
- [sacremoses](https://github.com/alvations/sacremoses): port of Moses
- Install with `pip install sacremoses`
Args:
- bypass_tokenizer: bool (default: False). Allows users to preprocess and tokenize the sentences externally. If True, we only apply BPE.
Returns:
List of tokens.
"""
lang = "fr"
if lang and self.lang2id and lang not in self.lang2id:
logger.error(
"Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model."
)
if bypass_tokenizer:
text = text.split()
else:
text = self.preprocess_text(text)
text = self.moses_pipeline(text, lang=lang)
text = self.moses_tokenize(text, lang=lang)
split_tokens = []
for token in text:
if token:
split_tokens.extend(self.bpe(token).split(" "))
return split_tokens
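Finally, a short usage sketch (assuming the `flaubert-base-cased` checkpoint registered above is reachable and `sacremoses` is installed):

from transformers import FlaubertTokenizer

tokenizer = FlaubertTokenizer.from_pretrained("flaubert-base-cased")
tokens = tokenizer.tokenize("Le chat mange une pomme.")  # Moses preprocessing, then BPE pieces
ids = tokenizer.encode("Le chat mange une pomme.", add_special_tokens=True)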