mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge pull request #1860 from stefan-it/camembert-for-token-classification
[WIP] Add support for CamembertForTokenClassification
This commit is contained in:
commit
0cdfcca24b
@ -37,6 +37,7 @@ from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
||||
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
||||
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
|
||||
from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -47,7 +48,8 @@ ALL_MODELS = sum(
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
|
@ -100,6 +100,7 @@ if is_torch_available():
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_camembert import (CamembertForMaskedLM, CamembertModel,
|
||||
CamembertForSequenceClassification, CamembertForMultipleChoice,
|
||||
CamembertForTokenClassification,
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
||||
|
||||
|
@ -20,7 +20,7 @@ from __future__ import (absolute_import, division, print_function,
|
||||
|
||||
import logging
|
||||
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice
|
||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification
|
||||
from .configuration_camembert import CamembertConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
@ -255,3 +255,39 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):
|
||||
"""
|
||||
config_class = CamembertConfig
|
||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@add_start_docstrings("""CamemBERT Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING)
|
||||
class CamembertForTokenClassification(RobertaForTokenClassification):
|
||||
r"""
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for computing the token classification loss.
|
||||
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss.
|
||||
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
||||
Classification scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
|
||||
model = CamembertForTokenClassification.from_pretrained('camembert-base')
|
||||
input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, scores = outputs[:2]
|
||||
|
||||
"""
|
||||
config_class = CamembertConfig
|
||||
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
@ -16,9 +16,14 @@
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import logging
|
||||
import os
|
||||
from shutil import copyfile
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
|
||||
|
||||
@ -55,6 +60,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
||||
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(str(vocab_file))
|
||||
self.vocab_file = vocab_file
|
||||
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
|
||||
# sentencepiece vocabulary (this is the case for <s> and </s>
|
||||
self.fairseq_tokens_to_ids = {'<s>NOTUSED': 0, '<pad>': 1, '</s>NOTUSED': 2, '<unk>': 3}
|
||||
@ -135,3 +141,17 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
||||
if index in self.fairseq_ids_to_tokens:
|
||||
return self.fairseq_ids_to_tokens[index]
|
||||
return self.sp_model.IdToPiece(index - self.fairseq_offset)
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
Loading…
Reference in New Issue
Block a user