add GPTNeoXForSequenceClassification (#22671)

* add GPTNeoXForSequenceClassification

* move the labels to logits.device (ref: #22561)

* fix
Authored by Sugawara on 2023-04-11 00:52:23 +09:00, committed by GitHub
parent f74b40208d
commit 6daa9cb515
8 changed files with 175 additions and 6 deletions
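
Note: in practice the new head is reachable through the auto classes. A minimal usage sketch follows; the checkpoint name, label count, and padding setup are illustrative assumptions and not part of this commit, and a base GPT-NeoX checkpoint gets a freshly initialized classification head that needs fine-tuning before its predictions mean anything:

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    name = "EleutherAI/pythia-70m"  # any GPT-NeoX-architecture checkpoint (illustrative)
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=2)

    # GPT-NeoX tokenizers ship without a pad token; the head locates the last
    # non-padding token via pad_token_id, so define one before batching.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    inputs = tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (2, num_labels)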

docs/source/en/model_doc/gpt_neox.mdx

@@ -78,3 +78,8 @@ The `generate()` method can be used to generate text using GPT Neo model.
 [[autodoc]] GPTNeoXForCausalLM
     - forward
 
+## GPTNeoXForSequenceClassification
+
+[[autodoc]] GPTNeoXForSequenceClassification
+    - forward

docs/source/en/tasks/sequence_classification.mdx

@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model architectures:
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
 
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 <!--End of the generated tip-->

src/transformers/__init__.py

@@ -1666,6 +1666,7 @@ else:
         [
             "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTNeoXForCausalLM",
+            "GPTNeoXForSequenceClassification",
             "GPTNeoXLayer",
             "GPTNeoXModel",
             "GPTNeoXPreTrainedModel",
@@ -5164,6 +5165,7 @@ if TYPE_CHECKING:
         from .models.gpt_neox import (
             GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTNeoXForCausalLM,
+            GPTNeoXForSequenceClassification,
             GPTNeoXLayer,
             GPTNeoXModel,
             GPTNeoXPreTrainedModel,

src/transformers/models/auto/modeling_auto.py

@@ -659,6 +659,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("gpt2", "GPT2ForSequenceClassification"),
         ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
         ("gpt_neo", "GPTNeoForSequenceClassification"),
+        ("gpt_neox", "GPTNeoXForSequenceClassification"),
         ("gptj", "GPTJForSequenceClassification"),
         ("ibert", "IBertForSequenceClassification"),
         ("layoutlm", "LayoutLMForSequenceClassification"),

src/transformers/models/gpt_neox/__init__.py

@@ -36,6 +36,7 @@ else:
     _import_structure["modeling_gpt_neox"] = [
         "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTNeoXForCausalLM",
+        "GPTNeoXForSequenceClassification",
         "GPTNeoXLayer",
         "GPTNeoXModel",
         "GPTNeoXPreTrainedModel",
@@ -62,6 +63,7 @@ if TYPE_CHECKING:
         from .modeling_gpt_neox import (
             GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTNeoXForCausalLM,
+            GPTNeoXForSequenceClassification,
             GPTNeoXLayer,
             GPTNeoXModel,
             GPTNeoXPreTrainedModel,

src/transformers/models/gpt_neox/modeling_gpt_neox.py

@@ -19,7 +19,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -28,7 +28,7 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_gpt_neox import GPTNeoXConfig
@@ -730,3 +730,131 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The GPTNeoX Model transformer with a sequence classification head on top (linear layer).
+
+    [`GPTNeoXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-1) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    GPT_NEOX_START_DOCSTRING,
+)
+class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.gpt_neox = GPTNeoXModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.gpt_neox(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
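
Note: the pooling step above is the heart of the head. In a causal model only the final non-padding position has attended to the whole sequence, so a single logit vector per row is selected instead of mean-pooling. A standalone sketch of that selection with made-up numbers (like the modeling code, it assumes right-side padding):

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 6, 7, 0, 0],       # 3 real tokens, 2 pads
                              [8, 9, 10, 11, 12]])   # no padding
    logits = torch.randn(2, 5, 3)  # stand-in for self.score(hidden_states)

    # index of the last non-padding token per row: here [2, 4]
    sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
    pooled = logits[torch.arange(2), sequence_lengths]  # shape (2, 3)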

src/transformers/utils/dummy_pt_objects.py

@@ -3236,6 +3236,13 @@ class GPTNeoXForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class GPTNeoXForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GPTNeoXLayer(metaclass=DummyObject):
     _backends = ["torch"]
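
Note: the dummy entry keeps `from transformers import GPTNeoXForSequenceClassification` importable when torch is absent, failing only at instantiation. A simplified re-creation of the pattern (not the actual transformers source; the real `requires_backends` lives in `transformers.utils`):

    TORCH_AVAILABLE = False  # pretend torch is not installed

    def requires_backends(obj, backends):
        if "torch" in backends and not TORCH_AVAILABLE:
            raise ImportError(f"{type(obj).__name__} requires the PyTorch library, but it was not found.")

    class GPTNeoXForSequenceClassification:
        _backends = ["torch"]

        def __init__(self, *args, **kwargs):
            requires_backends(self, ["torch"])

    GPTNeoXForSequenceClassification()  # raises ImportError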

tests/models/gpt_neox/test_modeling_gpt_neox.py

@@ -29,7 +29,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
-    from transformers import GPTNeoXForCausalLM, GPTNeoXModel
+    from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
 
 
 class GPTNeoXModelTester:
@@ -80,6 +80,7 @@ class GPTNeoXModelTester:
         self.num_labels = num_labels
         self.num_choices = num_choices
         self.scope = scope
+        self.pad_token_id = vocab_size - 1
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -110,6 +111,7 @@ class GPTNeoXModelTester:
             type_vocab_size=self.type_vocab_size,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            pad_token_id=self.pad_token_id,
         )
 
     def prepare_config_and_inputs_for_decoder(self):
@@ -142,6 +144,15 @@ class GPTNeoXModelTester:
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
+        config.num_labels = self.num_labels
+        model = GPTNeoXForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+        result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
         config.is_decoder = True
         model = GPTNeoXForCausalLM(config=config)
@@ -188,10 +199,19 @@ class GPTNeoXModelTester:
 @require_torch
 class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
+    all_model_classes = (
+        (GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
+    )
     all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
-        {"feature-extraction": GPTNeoXModel, "text-generation": GPTNeoXForCausalLM} if is_torch_available() else {}
+        {
+            "feature-extraction": GPTNeoXModel,
+            "text-classification": GPTNeoXForSequenceClassification,
+            "text-generation": GPTNeoXForCausalLM,
+            "zero-shot": GPTNeoXForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
     )
     test_pruning = False
     test_missing_keys = False
@@ -229,6 +249,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
 
+    def test_model_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
     @unittest.skip(reason="Feed forward chunking is not implemented")
     def test_feed_forward_chunking(self):
         pass
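
Note: the shape check the tester performs can be reproduced outside the harness with a tiny random config. Every size below is arbitrary, and the long-dtype labels steer the modeling code into its "single_label_classification" cross-entropy branch:

    import torch
    from transformers import GPTNeoXConfig, GPTNeoXForSequenceClassification

    config = GPTNeoXConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=37, num_labels=3,
        pad_token_id=98,  # mirrors the tester's vocab_size - 1 convention
    )
    model = GPTNeoXForSequenceClassification(config).eval()

    input_ids = torch.randint(0, 98, (2, 7))  # no pad tokens, so the last position is pooled
    labels = torch.tensor([0, 2])             # dtype long -> cross-entropy loss
    out = model(input_ids, labels=labels)
    assert out.logits.shape == (2, 3) and out.loss is not None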