mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
Merge pull request #1870 from alexzubiaga/xlnet-for-token-classification
XLNet for Token classification
This commit is contained in:
commit
d425a4d60b
@ -86,9 +86,10 @@ if is_torch_available():
|
||||
CTRLLMHeadModel,
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
||||
XLNetForSequenceClassification, XLNetForMultipleChoice,
|
||||
XLNetForQuestionAnsweringSimple, XLNetForQuestionAnswering,
|
||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
XLNetForSequenceClassification, XLNetForTokenClassification,
|
||||
XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple,
|
||||
XLNetForQuestionAnswering, load_tf_weights_in_xlnet,
|
||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
|
||||
@ -144,6 +145,7 @@ if is_tf_available():
|
||||
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
|
||||
TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
|
@ -938,6 +938,59 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
|
||||
return outputs # return logits, (mems), (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
||||
Classification scores (before SoftMax).
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``tf.Tensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
See details in the docstring of the `mems` input above.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
import tensorflow as tf
|
||||
from transformers import XLNetTokenizer, TFXLNetForTokenClassification
|
||||
|
||||
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||
model = TFXLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
|
||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
scores = outputs[0]
|
||||
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFXLNetForTokenClassification, self).__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = TFXLNetMainLayer(config, name='transformer')
|
||||
self.classifier = tf.keras.layers.Dense(config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name='classifier')
|
||||
|
||||
def call(self, inputs, **kwargs):
|
||||
transformer_outputs = self.transformer(inputs, **kwargs)
|
||||
output = transformer_outputs[0]
|
||||
|
||||
logits = self.classifier(output)
|
||||
|
||||
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||
|
||||
return outputs # return logits, (mems), (hidden states), (attentions)
|
||||
|
||||
|
||||
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
# the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
|
@ -1046,6 +1046,106 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
|
||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||
XLNET_START_DOCSTRING,
|
||||
XLNET_INPUTS_DOCSTRING)
|
||||
class XLNetForTokenClassification(XLNetPreTrainedModel):
|
||||
r"""
|
||||
Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
The second dimension of the input (`num_choices`) indicates the number of choices to scores.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for computing the multiple choice classification loss.
|
||||
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above)
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss.
|
||||
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
|
||||
Classification scores (before SoftMax).
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
See details in the docstring of the `mems` input above.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
|
||||
model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=labels)
|
||||
scores = outputs[0]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLNetForTokenClassification, self).__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.transformer = XLNetModel(config)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
||||
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
|
||||
|
||||
outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
mems=mems,
|
||||
perm_mask=perm_mask,
|
||||
target_mapping=target_mapping,
|
||||
token_type_ids=token_type_ids,
|
||||
input_mask=input_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
outputs = (logits,) + outputs[1:] # Keep mems, hidden states, attentions if there are in it
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
# Only keep active parts of the loss
|
||||
if attention_mask is not None:
|
||||
active_loss = attention_mask.view(-1) == 1
|
||||
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
||||
active_labels = labels.view(-1)[active_loss]
|
||||
loss = loss_fct(active_logits, active_labels)
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
|
||||
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
|
||||
|
@ -30,6 +30,7 @@ if is_tf_available():
|
||||
|
||||
from transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple,
|
||||
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
else:
|
||||
@ -42,6 +43,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
|
||||
all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
|
||||
TFXLNetForSequenceClassification,
|
||||
TFXLNetForTokenClassification,
|
||||
TFXLNetForQuestionAnsweringSimple) if is_tf_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@ -258,6 +260,26 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_for_token_classification(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
config.num_labels = input_ids_1.shape[1]
|
||||
model = TFXLNetForTokenClassification(config)
|
||||
inputs = {'input_ids': input_ids_1,
|
||||
'attention_mask': input_mask,
|
||||
# 'token_type_ids': token_type_ids
|
||||
}
|
||||
logits, mems_1 = model(inputs)
|
||||
result = {
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape),
|
||||
[self.batch_size, self.seq_length, config.num_labels])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
@ -289,6 +311,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
|
||||
|
||||
def test_xlnet_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_xlnet_qa(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -28,7 +28,8 @@ from transformers import is_torch_available
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
|
||||
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification,
|
||||
XLNetForTokenClassification, XLNetForQuestionAnswering)
|
||||
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
@ -38,7 +39,7 @@ from .configuration_common_test import ConfigTester
|
||||
|
||||
class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes=(XLNetModel, XLNetLMHeadModel,
|
||||
all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForTokenClassification,
|
||||
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@ -107,10 +108,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
sequence_labels = None
|
||||
lm_labels = None
|
||||
is_impossible_labels = None
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
config = XLNetConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
@ -129,14 +132,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
num_labels=self.type_sequence_label_size)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels)
|
||||
|
||||
def set_seed(self):
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetModel(config)
|
||||
model.eval()
|
||||
|
||||
@ -164,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetLMHeadModel(config)
|
||||
model.eval()
|
||||
|
||||
@ -204,7 +207,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetForQuestionAnswering(config)
|
||||
model.eval()
|
||||
|
||||
@ -261,8 +264,40 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetForTokenClassification(config)
|
||||
model.eval()
|
||||
|
||||
logits, mems_1 = model(input_ids_1)
|
||||
loss, logits, mems_1 = model(input_ids_1, labels=token_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"mems_1": mems_1,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()),
|
||||
[self.batch_size, self.seq_length, self.type_sequence_label_size])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels,
|
||||
sequence_labels, is_impossible_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||
model = XLNetForSequenceClassification(config)
|
||||
model.eval()
|
||||
|
||||
@ -289,7 +324,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||
target_mapping, segment_ids, lm_labels,
|
||||
sequence_labels, is_impossible_labels) = config_and_inputs
|
||||
sequence_labels, is_impossible_labels, token_labels) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids_1}
|
||||
return config, inputs_dict
|
||||
|
||||
@ -316,6 +351,11 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
|
||||
|
||||
def test_xlnet_token_classif(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlnet_token_classif(*config_and_inputs)
|
||||
|
||||
def test_xlnet_qa(self):
|
||||
self.model_tester.set_seed()
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
Loading…
Reference in New Issue
Block a user