diff --git a/transformers/__init__.py b/transformers/__init__.py
index 970bdf0cf1d..f06ee3f35df 100644
--- a/transformers/__init__.py
+++ b/transformers/__init__.py
@@ -86,9 +86,10 @@ if is_torch_available():
                                 CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
-                                 XLNetForSequenceClassification, XLNetForMultipleChoice,
-                                 XLNetForQuestionAnsweringSimple, XLNetForQuestionAnswering,
-                                 load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
+                                 XLNetForSequenceClassification, XLNetForTokenClassification,
+                                 XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple,
+                                 XLNetForQuestionAnswering, load_tf_weights_in_xlnet,
+                                 XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
                                XLMWithLMHeadModel, XLMForSequenceClassification,
                                XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
@@ -144,6 +145,7 @@ if is_tf_available():
     from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
                                     TFXLNetModel, TFXLNetLMHeadModel,
                                     TFXLNetForSequenceClassification,
+                                    TFXLNetForTokenClassification,
                                     TFXLNetForQuestionAnsweringSimple,
                                     TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
diff --git a/transformers/modeling_tf_xlnet.py b/transformers/modeling_tf_xlnet.py
index 215d906f572..759b57d8351 100644
--- a/transformers/modeling_tf_xlnet.py
+++ b/transformers/modeling_tf_xlnet.py
@@ -938,6 +938,59 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
         return outputs  # return logits, (mems), (hidden states), (attentions)
 
 
+@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
+class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
+    r"""
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+            Classification scores (before SoftMax).
+        **mems**: (`optional`, returned when ``config.mem_len > 0``)
+            list of ``tf.Tensor`` (one for each layer)
+            containing pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+            if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
+            See details in the docstring of the `mems` input above.
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import XLNetTokenizer, TFXLNetForTokenClassification
+
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
+        model = TFXLNetForTokenClassification.from_pretrained('xlnet-large-cased')
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        outputs = model(input_ids)
+        scores = outputs[0]
+
+    """
+    def __init__(self, config, *inputs, **kwargs):
+        super(TFXLNetForTokenClassification, self).__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.transformer = TFXLNetMainLayer(config, name='transformer')
+        self.classifier = tf.keras.layers.Dense(config.num_labels,
+                                                kernel_initializer=get_initializer(config.initializer_range),
+                                                name='classifier')
+
+    def call(self, inputs, **kwargs):
+        transformer_outputs = self.transformer(inputs, **kwargs)
+        output = transformer_outputs[0]
+
+        logits = self.classifier(output)
+
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states and attentions if they are present
+
+        return outputs  # return logits, (mems), (hidden states), (attentions)
+
+
 # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
 #                          the hidden-states output to compute `span start logits` and `span end logits`). """,
 #     XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py
index 658048a660b..2f4f8839056 100644
--- a/transformers/modeling_xlnet.py
+++ b/transformers/modeling_xlnet.py
@@ -1046,6 +1046,105 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
 
 
+@add_start_docstrings("""XLNet Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    XLNET_START_DOCSTRING,
+    XLNET_INPUTS_DOCSTRING)
+class XLNetForTokenClassification(XLNetPreTrainedModel):
+    r"""
+    Inputs:
+        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of input sequence tokens in the vocabulary.
+        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Segment token indices to indicate first and second portions of the inputs.
+            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
+            corresponds to a `sentence B` token.
+        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
+            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss.
+        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
+            Classification scores (before SoftMax).
+        **mems**: (`optional`, returned when ``config.mem_len > 0``)
+            list of ``torch.FloatTensor`` (one for each layer)
+            containing pre-computed hidden-states (key and values in the attention blocks) as computed by the model
+            if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
+            See details in the docstring of the `mems` input above.
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
+        model = XLNetForTokenClassification.from_pretrained('xlnet-large-cased')
+        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
+        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1, one label per token
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
+
+    """
+    def __init__(self, config):
+        super(XLNetForTokenClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.transformer = XLNetModel(config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
+                token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
+
+        outputs = self.transformer(input_ids,
+                                   attention_mask=attention_mask,
+                                   mems=mems,
+                                   perm_mask=perm_mask,
+                                   target_mapping=target_mapping,
+                                   token_type_ids=token_type_ids,
+                                   input_mask=input_mask,
+                                   head_mask=head_mask,
+                                   inputs_embeds=inputs_embeds)
+
+        sequence_output = outputs[0]
+
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[1:]  # Keep mems, hidden states and attentions if they are present
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)[active_loss]
+                active_labels = labels.view(-1)[active_loss]
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            outputs = (loss,) + outputs
+
+        return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
+
+
 @add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
     the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
""", XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING) diff --git a/transformers/tests/modeling_tf_xlnet_test.py b/transformers/tests/modeling_tf_xlnet_test.py index 12a8fbe36f0..a00a965570c 100644 --- a/transformers/tests/modeling_tf_xlnet_test.py +++ b/transformers/tests/modeling_tf_xlnet_test.py @@ -30,6 +30,7 @@ if is_tf_available(): from transformers.modeling_tf_xlnet import (TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, + TFXLNetForTokenClassification, TFXLNetForQuestionAnsweringSimple, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) else: @@ -42,6 +43,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): all_model_classes=(TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, + TFXLNetForTokenClassification, TFXLNetForQuestionAnsweringSimple) if is_tf_available() else () test_pruning = False @@ -258,6 +260,26 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): list(list(mem.shape) for mem in result["mems_1"]), [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + def create_and_check_xlnet_for_token_classification(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + config.num_labels = input_ids_1.shape[1] + model = TFXLNetForTokenClassification(config) + inputs = {'input_ids': input_ids_1, + 'attention_mask': input_mask, + # 'token_type_ids': token_type_ids + } + logits, mems_1 = model(inputs) + result = { + "mems_1": [mem.numpy() for mem in mems_1], + "logits": logits.numpy(), + } + self.parent.assertListEqual( + list(result["logits"].shape), + [self.batch_size, self.seq_length, config.num_labels]) + self.parent.assertListEqual( + list(list(mem.shape) for mem in result["mems_1"]), + [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, @@ -289,6 +311,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs) + def test_xlnet_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlnet_for_token_classification(*config_and_inputs) + def test_xlnet_qa(self): self.model_tester.set_seed() config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index d97ea6a425d..8f35d34e147 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -28,7 +28,8 @@ from transformers import is_torch_available if is_torch_available(): import torch - from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering) + from transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, + XLNetForTokenClassification, XLNetForQuestionAnswering) from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP else: pytestmark = pytest.mark.skip("Require Torch") @@ -38,7 +39,7 @@ from .configuration_common_test import ConfigTester class XLNetModelTest(CommonTestCases.CommonModelTester): - all_model_classes=(XLNetModel, XLNetLMHeadModel, + 
                        XLNetForSequenceClassification, XLNetForQuestionAnswering) if is_torch_available() else ()
 
     test_pruning = False
@@ -107,10 +108,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
             sequence_labels = None
             lm_labels = None
             is_impossible_labels = None
+            token_labels = None
             if self.use_labels:
                 lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
                 sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                 is_impossible_labels = ids_tensor([self.batch_size], 2).float()
+                token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_sequence_label_size)
 
             config = XLNetConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
@@ -129,14 +132,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 num_labels=self.type_sequence_label_size)
 
             return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)
+                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels)
 
         def set_seed(self):
             random.seed(self.seed)
             torch.manual_seed(self.seed)
 
         def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                              target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                              target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
             model = XLNetModel(config)
             model.eval()
 
@@ -164,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
 
         def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                           target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                           target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
             model = XLNetLMHeadModel(config)
             model.eval()
 
@@ -204,7 +207,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
 
         def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                      target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                      target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
             model = XLNetForQuestionAnswering(config)
             model.eval()
 
@@ -261,8 +264,32 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
                 list(list(mem.size()) for mem in result["mems"]),
                 [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
 
+        def create_and_check_xlnet_token_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
+                                                 target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
+            model = XLNetForTokenClassification(config)
+            model.eval()
+
+            logits, mems_1 = model(input_ids_1)
+            loss, logits, mems_1 = model(input_ids_1, labels=token_labels)
+
+            result = {
+                "loss": loss,
+                "mems_1": mems_1,
+                "logits": logits,
+            }
+
+            self.parent.assertListEqual(
+                list(result["loss"].size()),
+                [])
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.seq_length, self.type_sequence_label_size])
+            self.parent.assertListEqual(
+                list(list(mem.size()) for mem in result["mems_1"]),
+                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+
         def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
-                                                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
+                                                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
             model = XLNetForSequenceClassification(config)
             model.eval()
 
@@ -289,7 +316,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
              target_mapping, segment_ids, lm_labels,
-             sequence_labels, is_impossible_labels) = config_and_inputs
+             sequence_labels, is_impossible_labels, token_labels) = config_and_inputs
             inputs_dict = {'input_ids': input_ids_1}
             return config, inputs_dict
 
@@ -316,6 +343,11 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
 
+    def test_xlnet_token_classif(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_xlnet_token_classif(*config_and_inputs)
+
     def test_xlnet_qa(self):
         self.model_tester.set_seed()
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
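
The PyTorch forward pass above prepends the loss to the outputs and, when an ``attention_mask`` is supplied, excludes padded positions from it. A minimal local smoke test of the new head, not part of the patch itself, could look like the following; the tiny config sizes and the 5-label tag set are illustrative assumptions, not values taken from the repository::

    import torch
    from transformers import XLNetConfig, XLNetForTokenClassification

    # Deliberately small dimensions so the forward pass runs quickly on CPU.
    config = XLNetConfig(vocab_size_or_config_json_file=320, d_model=32,
                         n_layer=2, n_head=4, d_inner=64, num_labels=5)
    model = XLNetForTokenClassification(config)
    model.eval()

    batch_size, seq_len = 2, 10
    input_ids = torch.randint(0, 320, (batch_size, seq_len))
    labels = torch.randint(0, config.num_labels, (batch_size, seq_len))  # one tag id per token
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, -2:] = 0  # treat the last two positions as padding

    # With labels, outputs = (loss, logits, ...); masked positions do not contribute to the loss.
    loss, logits = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]
    assert logits.shape == (batch_size, seq_len, config.num_labels)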
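The TF head mirrors the PyTorch one but, like the other TF classes in this file, returns logits only and leaves the loss to the caller. A comparable sketch, again with an assumed toy config and made-up tag ids::

    import tensorflow as tf
    from transformers import XLNetConfig, TFXLNetForTokenClassification

    config = XLNetConfig(vocab_size_or_config_json_file=320, d_model=32,
                         n_layer=2, n_head=4, d_inner=64, num_labels=5)
    model = TFXLNetForTokenClassification(config)

    input_ids = tf.constant([[11, 22, 33, 44, 55]])  # batch size 1, sequence length 5
    logits = model(input_ids)[0]                     # (1, 5, num_labels)

    # Caller-side loss, e.g. for NER tags; the TF model computes no loss itself.
    tags = tf.constant([[0, 1, 2, 1, 0]])
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(tags, logits)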