mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Added Sequence Classification class in GPTNeo (#11906)
* seq classification changes * fix tests
This commit is contained in:
parent
80d712fac6
commit
e1205e478a
1
datasets
Submodule
1
datasets
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit d95b95f8cf3cb0cff5f77a675139b584dcfcf719
|
@ -65,3 +65,9 @@ GPTNeoForCausalLM
|
||||
|
||||
.. autoclass:: transformers.GPTNeoForCausalLM
|
||||
:members: forward
|
||||
|
||||
GPTNeoForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.GPTNeoForSequenceClassification
|
||||
:members: forward
|
||||
|
@ -746,6 +746,7 @@ if is_torch_available():
|
||||
[
|
||||
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"GPTNeoForCausalLM",
|
||||
"GPTNeoForSequenceClassification",
|
||||
"GPTNeoModel",
|
||||
"GPTNeoPreTrainedModel",
|
||||
"load_tf_weights_in_gpt_neo",
|
||||
@ -2129,6 +2130,7 @@ if TYPE_CHECKING:
|
||||
from .models.gpt_neo import (
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoForSequenceClassification,
|
||||
GPTNeoModel,
|
||||
GPTNeoPreTrainedModel,
|
||||
load_tf_weights_in_gpt_neo,
|
||||
|
@ -145,7 +145,7 @@ from ..funnel.modeling_funnel import (
|
||||
FunnelModel,
|
||||
)
|
||||
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
|
||||
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoModel
|
||||
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel
|
||||
from ..ibert.modeling_ibert import (
|
||||
IBertForMaskedLM,
|
||||
IBertForMultipleChoice,
|
||||
@ -632,6 +632,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(DebertaConfig, DebertaForSequenceClassification),
|
||||
(DebertaV2Config, DebertaV2ForSequenceClassification),
|
||||
(GPT2Config, GPT2ForSequenceClassification),
|
||||
(GPTNeoConfig, GPTNeoForSequenceClassification),
|
||||
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
|
||||
(ReformerConfig, ReformerForSequenceClassification),
|
||||
(CTRLConfig, CTRLForSequenceClassification),
|
||||
|
@ -28,6 +28,7 @@ if is_torch_available():
|
||||
_import_structure["modeling_gpt_neo"] = [
|
||||
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"GPTNeoForCausalLM",
|
||||
"GPTNeoForSequenceClassification",
|
||||
"GPTNeoModel",
|
||||
"GPTNeoPreTrainedModel",
|
||||
"load_tf_weights_in_gpt_neo",
|
||||
@ -41,6 +42,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_gpt_neo import (
|
||||
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoForSequenceClassification,
|
||||
GPTNeoModel,
|
||||
GPTNeoPreTrainedModel,
|
||||
load_tf_weights_in_gpt_neo,
|
||||
|
@ -22,7 +22,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
@ -31,6 +31,7 @@ from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
@ -1027,3 +1028,120 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPTNeo Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
:class:`~transformers.GPTNeoForSequenceClassification` uses the last token in order to do the classification, as
|
||||
other causal models (e.g. GPT-1) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
|
||||
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
|
||||
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
|
||||
the last value in each row of the batch).
|
||||
""",
|
||||
GPT_NEO_START_DOCSTRING,
|
||||
)
|
||||
class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.transformer = GPTNeoModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=SequenceClassifierOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = input_ids.shape[:2]
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[range(batch_size), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
@ -1603,6 +1603,15 @@ class GPTNeoForCausalLM:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GPTNeoForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GPTNeoModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
@ -361,7 +361,6 @@ class GPT2ModelTester:
|
||||
model = GPT2ForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
print(config.num_labels, sequence_labels.size())
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
|
@ -34,6 +34,7 @@ if is_torch_available():
|
||||
GPT2Tokenizer,
|
||||
GPTNeoConfig,
|
||||
GPTNeoForCausalLM,
|
||||
GPTNeoForSequenceClassification,
|
||||
GPTNeoModel,
|
||||
)
|
||||
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
|
||||
@ -238,6 +239,16 @@ class GPTNeoModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_gpt_neo_for_sequence_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTNeoForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPTNeoForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
@ -274,7 +285,9 @@ class GPTNeoModelTester:
|
||||
@require_torch
|
||||
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
|
||||
fx_ready_model_classes = all_model_classes
|
||||
test_missing_keys = False
|
||||
@ -305,6 +318,10 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_sequence_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_gpt_neo_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
|
||||
|
Loading…
Reference in New Issue
Block a user