From dcba9ee03b2129d1035a1d67a5fe52bca9ecbd49 Mon Sep 17 00:00:00 2001
From: Felipe Curti
Date: Tue, 13 Oct 2020 06:06:15 -0300
Subject: [PATCH] Gpt1 for sequence classification (#7683)

* Add Documentation for GPT-1 Classification

* Add GPT-1 with Classification head

* Add tests for GPT-1 Classification

* Add GPT-1 For Classification to auto models

* Remove authorized missing keys, change checkpoint to openai-gpt
---
 docs/source/model_doc/gpt.rst              |   7 ++
 src/transformers/__init__.py               |   1 +
 src/transformers/modeling_auto.py          |   3 +-
 src/transformers/modeling_openai.py        | 114 ++++++++++++++++++++-
 src/transformers/utils/dummy_pt_objects.py |   9 ++
 tests/test_modeling_openai.py              |  23 ++++-
 6 files changed, 153 insertions(+), 4 deletions(-)

diff --git a/docs/source/model_doc/gpt.rst b/docs/source/model_doc/gpt.rst
index 07deb89b66d..5f945227a4b 100644
--- a/docs/source/model_doc/gpt.rst
+++ b/docs/source/model_doc/gpt.rst
@@ -104,6 +104,13 @@ OpenAIGPTDoubleHeadsModel
     :members: forward
 
 
+OpenAIGPTForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.OpenAIGPTForSequenceClassification
+    :members: forward
+
+
 TFOpenAIGPTModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 143ff4187d4..4e6d8bdd344 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -437,6 +437,7 @@ if is_torch_available():
     from .modeling_openai import (
         OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         OpenAIGPTDoubleHeadsModel,
+        OpenAIGPTForSequenceClassification,
         OpenAIGPTLMHeadModel,
         OpenAIGPTModel,
         OpenAIGPTPreTrainedModel,
diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py
index 5d7f4812136..5dd0201bbf4 100644
--- a/src/transformers/modeling_auto.py
+++ b/src/transformers/modeling_auto.py
@@ -153,7 +153,7 @@ from .modeling_mobilebert import (
     MobileBertForTokenClassification,
     MobileBertModel,
 )
-from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
+from .modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
 from .modeling_pegasus import PegasusForConditionalGeneration
 from .modeling_rag import (  # noqa: F401 - need to import all RagModels to be in globals() function
     RagModel,
@@ -381,6 +381,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         (FunnelConfig, FunnelForSequenceClassification),
         (DebertaConfig, DebertaForSequenceClassification),
         (GPT2Config, GPT2ForSequenceClassification),
+        (OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
     ]
 )
 
diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py
index 1645e6ed3f2..4f449c67a48 100644
--- a/src/transformers/modeling_openai.py
+++ b/src/transformers/modeling_openai.py
@@ -25,7 +25,7 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
 
 from .activations import gelu_new, swish
 from .configuration_openai import OpenAIGPTConfig
@@ -36,7 +36,7 @@ from .file_utils import (
     add_start_docstrings_to_callable,
     replace_return_docstrings,
 )
-from .modeling_outputs import BaseModelOutput, CausalLMOutput
+from .modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from .modeling_utils import (
     Conv1D,
     PreTrainedModel,
@@ -732,3 +732,113 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """The Original OpenAI GPT Model transformer with a sequence classification head on top
+    (linear layer).
+    :class:`~transformers.OpenAIGPTForSequenceClassification` uses the last token in order to do the classification, as
+    other causal models (e.g. GPT-2) do.
+    Since it does classification on the last token, it needs to know the position of the last token.
+    If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token
+    in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch.
+    Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it
+    does the same (take the last value in each row of the batch).
+    """,
+    OPENAI_GPT_START_DOCSTRING,
+)
+class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.transformer = OpenAIGPTModel(config)
+        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="openai-gpt",
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
+            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        assert (
+            self.config.pad_token_id is not None or batch_size == 1
+        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    f"unexpected if using padding tokens in conjunction with `inputs_embeds`."
+                )
+
+        pooled_logits = logits[range(batch_size), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(pooled_logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=pooled_logits,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index b8fb77b55df..e2c1154262b 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -1256,6 +1256,15 @@ class OpenAIGPTDoubleHeadsModel:
         requires_pytorch(self)
 
 
+class OpenAIGPTForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class OpenAIGPTLMHeadModel:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
diff --git a/tests/test_modeling_openai.py b/tests/test_modeling_openai.py
index 92a0335cda7..e74ce093fa1 100644
--- a/tests/test_modeling_openai.py
+++ b/tests/test_modeling_openai.py
@@ -30,6 +30,7 @@ if is_torch_available():
         OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         OpenAIGPTConfig,
         OpenAIGPTDoubleHeadsModel,
+        OpenAIGPTForSequenceClassification,
         OpenAIGPTLMHeadModel,
         OpenAIGPTModel,
     )
@@ -61,6 +62,7 @@ class OpenAIGPTModelTester:
         self.num_labels = 3
         self.num_choices = 4
         self.scope = None
+        self.pad_token_id = self.vocab_size - 1
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -90,6 +92,7 @@ class OpenAIGPTModelTester:
             n_ctx=self.max_position_embeddings,
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range
+            pad_token_id=self.pad_token_id,
             return_dict=True,
         )
 
@@ -134,6 +137,18 @@ class OpenAIGPTModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_openai_gpt_for_sequence_classification(
+        self, config, input_ids, head_mask, token_type_ids, *args
+    ):
+        config.num_labels = self.num_labels
+        model = OpenAIGPTForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        # print(config.num_labels, sequence_labels.size())
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+        result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -158,7 +173,9 @@ class OpenAIGPTModelTester:
 class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
 
     all_model_classes = (
-        (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
+        (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
+        if is_torch_available()
+        else ()
     )
     all_generative_model_classes = (
         (OpenAIGPTLMHeadModel,) if is_torch_available() else ()
@@ -183,6 +200,10 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_double_lm_head_model(*config_and_inputs)
 
+    def test_openai_gpt_classification_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
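
For reference, here is a minimal usage sketch of the new head (not part of the patch itself), assuming a transformers install that already contains this commit; the example sentence and num_labels=2 are arbitrary choices for illustration:

import torch
from transformers import OpenAIGPTForSequenceClassification, OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
# num_labels sizes the nn.Linear score head added by this patch; 2 is an arbitrary example,
# and the head is freshly initialized, so it should be fine-tuned before predictions mean anything.
model = OpenAIGPTForSequenceClassification.from_pretrained("openai-gpt", num_labels=2)
model.eval()

# A single, unpadded sequence: the forward pass only requires config.pad_token_id when
# batch_size > 1, since it uses it to locate the last non-padding token in each row.
inputs = tokenizer("this movie was surprisingly good", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, return_dict=True)

# Logits are pooled from the last token of the sequence, shape (batch_size, num_labels).
predicted_class = outputs.logits.argmax(dim=-1).item()

For padded batches, model.config.pad_token_id must be set (the pretrained openai-gpt tokenizer does not define a padding token by default) so the model can find the last real token in each row.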