mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Ctrl for sequence classification (#8812)
* add CTRLForSequenceClassification * pass local test * merge with master * fix modeling test for sequence classification * fix deco * fix assert
This commit is contained in:
parent
7f34d75780
commit
4a9e502a36
@ -65,6 +65,13 @@ CTRLLMHeadModel
|
||||
:members: forward
|
||||
|
||||
|
||||
CTRLForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.CTRLForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TFCTRLModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -391,7 +391,13 @@ if is_torch_available():
|
||||
CamembertForTokenClassification,
|
||||
CamembertModel,
|
||||
)
|
||||
from .models.ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
|
||||
from .models.ctrl import (
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CTRLForSequenceClassification,
|
||||
CTRLLMHeadModel,
|
||||
CTRLModel,
|
||||
CTRLPreTrainedModel,
|
||||
)
|
||||
from .models.deberta import (
|
||||
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DebertaForSequenceClassification,
|
||||
|
@ -60,7 +60,7 @@ from ..camembert.modeling_camembert import (
|
||||
CamembertForTokenClassification,
|
||||
CamembertModel,
|
||||
)
|
||||
from ..ctrl.modeling_ctrl import CTRLLMHeadModel, CTRLModel
|
||||
from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel
|
||||
from ..deberta.modeling_deberta import DebertaForSequenceClassification, DebertaModel
|
||||
from ..distilbert.modeling_distilbert import (
|
||||
DistilBertForMaskedLM,
|
||||
@ -415,6 +415,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(GPT2Config, GPT2ForSequenceClassification),
|
||||
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
|
||||
(ReformerConfig, ReformerForSequenceClassification),
|
||||
(CTRLConfig, CTRLForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,13 @@ from .tokenization_ctrl import CTRLTokenizer
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel
|
||||
from .modeling_ctrl import (
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CTRLForSequenceClassification,
|
||||
CTRLLMHeadModel,
|
||||
CTRLModel,
|
||||
CTRLPreTrainedModel,
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_ctrl import (
|
||||
|
@ -18,10 +18,10 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
|
||||
from ...modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import logging
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
@ -571,3 +571,117 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The CTRL Model transformer with a sequence classification head on top (linear layer).
|
||||
:class:`~transformers.CTRLForSequenceClassification` uses the last token in order to do the classification, as
|
||||
other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the
|
||||
position of the last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that
|
||||
is not a padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each
|
||||
row of the batch. Since it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of
|
||||
:obj:`input_ids`, it does the same (take the last value in each row of the batch).
|
||||
""",
|
||||
CTRL_START_DOCSTRING,
|
||||
)
|
||||
class CTRLForSequenceClassification(CTRLPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = CTRLModel(config)
|
||||
self.classifier = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="ctrl",
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = input_ids.shape[:2]
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
f"unexpected if using padding tokens in conjuction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
@ -634,6 +634,15 @@ class CamembertModel:
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class CTRLForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class CTRLLMHeadModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
@ -26,7 +26,13 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLConfig, CTRLLMHeadModel, CTRLModel
|
||||
from transformers import (
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CTRLConfig,
|
||||
CTRLForSequenceClassification,
|
||||
CTRLLMHeadModel,
|
||||
CTRLModel,
|
||||
)
|
||||
|
||||
|
||||
class CTRLModelTester:
|
||||
@ -57,6 +63,7 @@ class CTRLModelTester:
|
||||
self.num_labels = 3
|
||||
self.num_choices = 4
|
||||
self.scope = None
|
||||
self.pad_token_id = self.vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -94,6 +101,7 @@ class CTRLModelTester:
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
@ -149,11 +157,20 @@ class CTRLModelTester:
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_ctrl_for_sequence_classification(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
config.num_labels = self.num_labels
|
||||
model = CTRLForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
|
||||
@require_torch
|
||||
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
|
||||
all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
|
||||
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
|
||||
test_pruning = True
|
||||
test_torchscript = False
|
||||
|
Loading…
Reference in New Issue
Block a user