Ctrl for sequence classification (#8812)

* add CTRLForSequenceClassification

* pass local test

* merge with master

* fix modeling test for sequence classification

* fix deco

* fix assert
This commit is contained in:
elk-cloner 2020-12-01 12:19:27 +03:30 committed by GitHub
parent 7f34d75780
commit 4a9e502a36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 167 additions and 7 deletions

View File

@ -65,6 +65,13 @@ CTRLLMHeadModel
:members: forward
CTRLForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.CTRLForSequenceClassification
:members: forward
TFCTRLModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -391,7 +391,13 @@ if is_torch_available():
CamembertForTokenClassification,
CamembertModel,
)
from .models.ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
from .models.ctrl import (
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
CTRLForSequenceClassification,
CTRLLMHeadModel,
CTRLModel,
CTRLPreTrainedModel,
)
from .models.deberta import (
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
DebertaForSequenceClassification,

View File

@ -60,7 +60,7 @@ from ..camembert.modeling_camembert import (
CamembertForTokenClassification,
CamembertModel,
)
from ..ctrl.modeling_ctrl import CTRLLMHeadModel, CTRLModel
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
from ..deberta.modeling_deberta import DebertaForSequenceClassification, DebertaModel
from ..distilbert.modeling_distilbert import (
DistilBertForMaskedLM,
@ -415,6 +415,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(GPT2Config, GPT2ForSequenceClassification),
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification),
(CTRLConfig, CTRLForSequenceClassification),
]
)

View File

@ -8,7 +8,13 @@ from .tokenization_ctrl import CTRLTokenizer
if is_torch_available():
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
from .modeling_ctrl import (
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
CTRLForSequenceClassification,
CTRLLMHeadModel,
CTRLModel,
CTRLPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_ctrl import (

View File

@ -18,10 +18,10 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_ctrl import CTRLConfig
@ -571,3 +571,117 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The CTRL Model transformer with a sequence classification head on top (linear layer).
:class:`~transformers.CTRLForSequenceClassification` uses the last token in order to do the classification, as
other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the
position of the last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that
is not a padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each
row of the batch. Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of
:obj:`input_ids`, it does the same (take the last value in each row of the batch).
""",
CTRL_START_DOCSTRING,
)
class CTRLForSequenceClassification(CTRLPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = CTRLModel(config)
self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)
self.init_weights()
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.classifier(hidden_states)
if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
f"unexpected if using padding tokens in conjuction with `inputs_embeds.`"
)
pooled_logits = logits[range(batch_size), sequence_lengths]
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (pooled_logits,) + transformer_outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=pooled_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

View File

@ -634,6 +634,15 @@ class CamembertModel:
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
class CTRLForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class CTRLLMHeadModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)

View File

@ -26,7 +26,13 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLConfig, CTRLLMHeadModel, CTRLModel
from transformers import (
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
CTRLConfig,
CTRLForSequenceClassification,
CTRLLMHeadModel,
CTRLModel,
)
class CTRLModelTester:
@ -57,6 +63,7 @@ class CTRLModelTester:
self.num_labels = 3
self.num_choices = 4
self.scope = None
self.pad_token_id = self.vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -94,6 +101,7 @@ class CTRLModelTester:
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@ -149,11 +157,20 @@ class CTRLModelTester:
return config, inputs_dict
def create_and_check_ctrl_for_sequence_classification(self, config, input_ids, head_mask, token_type_ids, *args):
config.num_labels = self.num_labels
model = CTRLForSequenceClassification(config)
model.to(torch_device)
model.eval()
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
@require_torch
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
test_pruning = True
test_torchscript = False