Merge pull request #1870 from alexzubiaga/xlnet-for-token-classification

XLNet for Token classification
This commit is contained in:
Thomas Wolf 2019-12-05 09:54:09 +01:00 committed by GitHub
commit d425a4d60b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 232 additions and 11 deletions

View File

@ -86,9 +86,10 @@ if is_torch_available():
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForMultipleChoice,
XLNetForQuestionAnsweringSimple, XLNetForQuestionAnswering,
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
XLNetForSequenceClassification, XLNetForTokenClassification,
XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple,
XLNetForQuestionAnswering, load_tf_weights_in_xlnet,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
@ -144,6 +145,7 @@ if is_tf_available():
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)

View File

@ -938,6 +938,59 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
return outputs # return logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import XLNetTokenizer, TFXLNetForTokenClassification
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = TFXLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFXLNetForTokenClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.transformer = TFXLNetMainLayer(config, name='transformer')
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs):
transformer_outputs = self.transformer(inputs, **kwargs)
output = transformer_outputs[0]
logits = self.classifier(output)
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, (mems), (hidden states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)

View File

@ -1046,6 +1046,106 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLNET_START_DOCSTRING,
XLNET_INPUTS_DOCSTRING)
class XLNetForTokenClassification(XLNetPreTrainedModel):
r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to scores.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
scores = outputs[0]
"""
def __init__(self, config):
super(XLNetForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels
self.transformer = XLNetModel(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
outputs = self.transformer(input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[1:] # Keep mems, hidden states, attentions if there are in it
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)

View File

@ -30,6 +30,7 @@ if is_tf_available():
from transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
@ -42,6 +43,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple) if is_tf_available() else ()
test_pruning = False
@ -258,6 +260,26 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_for_token_classification(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
config.num_labels = input_ids_1.shape[1]
model = TFXLNetForTokenClassification(config)
inputs = {'input_ids': input_ids_1,
'attention_mask': input_mask,
# 'token_type_ids': token_type_ids
}
logits, mems_1 = model(inputs)
result = {
"mems_1": [mem.numpy() for mem in mems_1],
"logits": logits.numpy(),
}
self.parent.assertListEqual(
list(result["logits"].shape),
[self.batch_size, self.seq_length, config.num_labels])
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
@ -289,6 +311,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
def test_xlnet_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_for_token_classification(*config_and_inputs)
def test_xlnet_qa(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -28,7 +28,8 @@ from transformers import is_torch_available
if is_torch_available():
import torch
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification,
XLNetForTokenClassification, XLNetForQuestionAnswering)
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
@ -38,7 +39,7 @@ from .configuration_common_test import ConfigTester
class XLNetModelTest(CommonTestCases.CommonModelTester):
all_model_classes=(XLNetModel, XLNetLMHeadModel,
all_model_classes=(XLNetModel, XLNetLMHeadModel, XLNetForTokenClassification,
XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
test_pruning = False
@ -107,10 +108,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
sequence_labels = None
lm_labels = None
is_impossible_labels = None
token_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
@ -129,14 +132,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
num_labels=self.type_sequence_label_size)
return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels)
def set_seed(self):
random.seed(self.seed)
torch.manual_seed(self.seed)
def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetModel(config)
model.eval()
@ -164,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetLMHeadModel(config)
model.eval()
@ -204,7 +207,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetForQuestionAnswering(config)
model.eval()
@ -261,8 +264,40 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetForTokenClassification(config)
model.eval()
logits, mems_1 = model(input_ids_1)
loss, logits, mems_1 = model(input_ids_1, labels=token_labels)
result = {
"loss": loss,
"mems_1": mems_1,
"logits": logits,
}
self.parent.assertListEqual(
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.type_sequence_label_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels,
sequence_labels, is_impossible_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1}
return config, inputs_dict
def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetForSequenceClassification(config)
model.eval()
@ -289,7 +324,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels,
sequence_labels, is_impossible_labels) = config_and_inputs
sequence_labels, is_impossible_labels, token_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids_1}
return config, inputs_dict
@ -316,6 +351,11 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
def test_xlnet_token_classif(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_token_classif(*config_and_inputs)
def test_xlnet_qa(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()