mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add GPTNeoXForSequenceClassification (#22671)
* add GPTNeoXForSequenceClassification * move the labels to logits.device (ref: #22561) * fix
This commit is contained in:
parent
f74b40208d
commit
6daa9cb515
@ -78,3 +78,8 @@ The `generate()` method can be used to generate text using GPT Neo model.
|
||||
|
||||
[[autodoc]] GPTNeoXForCausalLM
|
||||
- forward
|
||||
|
||||
## GPTNeoXForSequenceClassification
|
||||
|
||||
[[autodoc]] GPTNeoXForSequenceClassification
|
||||
- forward
|
@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
|
||||
|
||||
|
||||
<!--End of the generated tip-->
|
||||
|
@ -1666,6 +1666,7 @@ else:
|
||||
[
|
||||
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"GPTNeoXForCausalLM",
|
||||
"GPTNeoXForSequenceClassification",
|
||||
"GPTNeoXLayer",
|
||||
"GPTNeoXModel",
|
||||
"GPTNeoXPreTrainedModel",
|
||||
@ -5164,6 +5165,7 @@ if TYPE_CHECKING:
|
||||
from .models.gpt_neox import (
|
||||
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPTNeoXForCausalLM,
|
||||
GPTNeoXForSequenceClassification,
|
||||
GPTNeoXLayer,
|
||||
GPTNeoXModel,
|
||||
GPTNeoXPreTrainedModel,
|
||||
|
@ -659,6 +659,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("gpt2", "GPT2ForSequenceClassification"),
|
||||
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
|
||||
("gpt_neo", "GPTNeoForSequenceClassification"),
|
||||
("gpt_neox", "GPTNeoXForSequenceClassification"),
|
||||
("gptj", "GPTJForSequenceClassification"),
|
||||
("ibert", "IBertForSequenceClassification"),
|
||||
("layoutlm", "LayoutLMForSequenceClassification"),
|
||||
|
@ -36,6 +36,7 @@ else:
|
||||
_import_structure["modeling_gpt_neox"] = [
|
||||
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"GPTNeoXForCausalLM",
|
||||
"GPTNeoXForSequenceClassification",
|
||||
"GPTNeoXLayer",
|
||||
"GPTNeoXModel",
|
||||
"GPTNeoXPreTrainedModel",
|
||||
@ -62,6 +63,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_gpt_neox import (
|
||||
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
GPTNeoXForCausalLM,
|
||||
GPTNeoXForSequenceClassification,
|
||||
GPTNeoXLayer,
|
||||
GPTNeoXModel,
|
||||
GPTNeoXPreTrainedModel,
|
||||
|
@ -19,7 +19,7 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
@ -28,7 +28,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_gpt_neox import GPTNeoXConfig
|
||||
@ -730,3 +730,131 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
||||
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPTNeoX Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`GPTNeoXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-1) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
""",
|
||||
GPT_NEOX_START_DOCSTRING,
|
||||
)
|
||||
class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.gpt_neox = GPTNeoXModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=SequenceClassifierOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.gpt_neox(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = input_ids.shape[:2]
|
||||
else:
|
||||
batch_size, sequence_length = inputs_embeds.shape[:2]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if input_ids is not None:
|
||||
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
||||
self.config.problem_type = "single_label_classification"
|
||||
else:
|
||||
self.config.problem_type = "multi_label_classification"
|
||||
|
||||
if self.config.problem_type == "regression":
|
||||
loss_fct = MSELoss()
|
||||
if self.num_labels == 1:
|
||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
||||
else:
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
elif self.config.problem_type == "single_label_classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
print(pooled_logits.shape, labels.shape)
|
||||
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(pooled_logits, labels)
|
||||
if not return_dict:
|
||||
output = (pooled_logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -3236,6 +3236,13 @@ class GPTNeoXForCausalLM(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GPTNeoXForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GPTNeoXLayer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -29,7 +29,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import GPTNeoXForCausalLM, GPTNeoXModel
|
||||
from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
|
||||
|
||||
|
||||
class GPTNeoXModelTester:
|
||||
@ -80,6 +80,7 @@ class GPTNeoXModelTester:
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.pad_token_id = vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -110,6 +111,7 @@ class GPTNeoXModelTester:
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
@ -142,6 +144,15 @@ class GPTNeoXModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = GPTNeoXForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
|
||||
config.is_decoder = True
|
||||
model = GPTNeoXForCausalLM(config=config)
|
||||
@ -188,10 +199,19 @@ class GPTNeoXModelTester:
|
||||
|
||||
@require_torch
|
||||
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": GPTNeoXModel, "text-generation": GPTNeoXForCausalLM} if is_torch_available() else {}
|
||||
{
|
||||
"feature-extraction": GPTNeoXModel,
|
||||
"text-classification": GPTNeoXForSequenceClassification,
|
||||
"text-generation": GPTNeoXForCausalLM,
|
||||
"zero-shot": GPTNeoXForSequenceClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
@ -229,6 +249,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
||||
|
||||
def test_model_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user