Merge pull request #1860 from stefan-it/camembert-for-token-classification

[WIP] Add support for CamembertForTokenClassification
Thomas Wolf 2019-11-21 10:56:07 +01:00 committed by GitHub
commit 0cdfcca24b
4 changed files with 61 additions and 2 deletions

examples/run_ner.py

@@ -37,6 +37,7 @@ from transformers import AdamW, get_linear_schedule_with_warmup
 from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
 from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
 from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
+from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer

 logger = logging.getLogger(__name__)
@@ -47,7 +48,8 @@ ALL_MODELS = sum(
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
     "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
+    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
+    "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
 }
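With the registry extended, a "camembert" model type resolves to the new classes exactly like the existing entries. A minimal sketch of the lookup (the checkpoint name is the public camembert-base; num_labels=9 is an illustrative value, not part of this diff):

model_type = "camembert"
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]

# e.g. a CoNLL-style NER label set; num_labels here is only illustrative
config = config_class.from_pretrained("camembert-base", num_labels=9)
tokenizer = tokenizer_class.from_pretrained("camembert-base")
model = model_class.from_pretrained("camembert-base", config=config)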

transformers/__init__.py

@@ -100,6 +100,7 @@ if is_torch_available():
                                      DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_camembert import (CamembertForMaskedLM, CamembertModel,
                                      CamembertForSequenceClassification, CamembertForMultipleChoice,
+                                     CamembertForTokenClassification,
                                      CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
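Re-exporting the class here makes it importable from the package root; a quick check:

from transformers import CamembertForTokenClassification

model = CamembertForTokenClassification.from_pretrained('camembert-base')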

transformers/modeling_camembert.py

@@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
 import logging

-from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice
+from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
 from .configuration_camembert import CamembertConfig
 from .file_utils import add_start_docstrings
@@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
     """
     config_class = CamembertConfig
     pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
+
+
+@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
+class CamembertForTokenClassification(RobertaForTokenClassification):
+    r"""
+    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+        Labels for computing the token classification loss.
+        Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss.
+        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+            Classification scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
+        model = CamembertForTokenClassification.from_pretrained('camembert-base')
+        input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
+    """
+    config_class = CamembertConfig
+    pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
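The docstring example assumes torch and the classes are already imported; a self-contained variant that also turns the scores into per-token predictions (note the classification head is randomly initialized until fine-tuned):

import torch
from transformers import CamembertTokenizer, CamembertForTokenClassification

tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForTokenClassification.from_pretrained('camembert-base')
model.eval()

input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0)  # Batch size 1

with torch.no_grad():
    scores = model(input_ids)[0]  # shape (1, sequence_length, config.num_labels)

# Highest-scoring label id per token; mapping ids to tag names (e.g. 'O', 'B-PER', ...)
# depends on the fine-tuning setup and is not part of this diff.
predictions = torch.argmax(scores, dim=-1)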

transformers/tokenization_camembert.py

@@ -16,9 +16,14 @@
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)

 import logging
+import os
+from shutil import copyfile

 import sentencepiece as spm

 from transformers.tokenization_utils import PreTrainedTokenizer

 logger = logging.getLogger(__name__)

 VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
@@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         self.max_len_sentences_pair = self.max_len - 4  # take into account special tokens
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
         # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
         # sentencepiece vocabulary (this is the case for <s> and </s>).
         self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
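The four fairseq placeholder ids explain the offset arithmetic in _convert_id_to_token below: SentencePiece ids are shifted up so ids 0-3 stay reserved. A standalone illustration of the scheme (a sketch, not repository code):

fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
fairseq_offset = len(fairseq_tokens_to_ids)  # 4 reserved ids before the SentencePiece vocabulary
fairseq_ids_to_tokens = {v: k for k, v in fairseq_tokens_to_ids.items()}

def id_to_token(sp_model, index):
    # Mirrors _convert_id_to_token: reserved ids first, then shifted SentencePiece lookup.
    if index in fairseq_ids_to_tokens:
        return fairseq_ids_to_tokens[index]
    return sp_model.IdToPiece(index - fairseq_offset)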
@@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def save_vocabulary(self, save_directory):
+        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
+            to a directory.
+        """
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
+            return
+        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
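A quick round-trip sketch of the new save path (save_pretrained calls save_vocabulary under the hood and writes the special-tokens files alongside the copied model):

import tempfile
from transformers import CamembertTokenizer

tokenizer = CamembertTokenizer.from_pretrained('camembert-base')

with tempfile.TemporaryDirectory() as save_dir:
    tokenizer.save_pretrained(save_dir)  # copies sentencepiece.bpe.model into save_dir
    reloaded = CamembertTokenizer.from_pretrained(save_dir)
    assert reloaded.encode("J'aime le camembert !") == tokenizer.encode("J'aime le camembert !")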