Added Sequence Classification class in GPTNeo (#11906)

* seq classification changes

* fix tests
Bhadresh Savani 2021-05-28 15:57:02 +05:30 committed by GitHub
parent 80d712fac6
commit e1205e478a
9 changed files with 159 additions and 4 deletions

datasets Submodule

@@ -0,0 +1 @@
Subproject commit d95b95f8cf3cb0cff5f77a675139b584dcfcf719


@@ -65,3 +65,9 @@ GPTNeoForCausalLM
.. autoclass:: transformers.GPTNeoForCausalLM
    :members: forward


GPTNeoForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTNeoForSequenceClassification
    :members: forward


@@ -746,6 +746,7 @@ if is_torch_available():
[
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
"GPTNeoForSequenceClassification",
"GPTNeoModel",
"GPTNeoPreTrainedModel",
"load_tf_weights_in_gpt_neo",
@@ -2129,6 +2130,7 @@ if TYPE_CHECKING:
from .models.gpt_neo import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,


@@ -145,7 +145,7 @@ from ..funnel.modeling_funnel import (
FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
-from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoModel
+from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel
from ..ibert.modeling_ibert import (
IBertForMaskedLM,
IBertForMultipleChoice,
@@ -632,6 +632,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(DebertaConfig, DebertaForSequenceClassification),
(DebertaV2Config, DebertaV2ForSequenceClassification),
(GPT2Config, GPT2ForSequenceClassification),
(GPTNeoConfig, GPTNeoForSequenceClassification),
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification),
(CTRLConfig, CTRLForSequenceClassification),
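
For context: adding GPTNeoConfig to MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING is what lets the auto classes resolve a GPT-Neo checkpoint to the new head. A minimal sketch of that lookup, assuming the EleutherAI/gpt-neo-125M checkpoint purely as an example:

from transformers import AutoModelForSequenceClassification

# "EleutherAI/gpt-neo-125M" is an assumed example checkpoint; num_labels sizes the new classification
# head, whose weights are freshly initialized (the usual "newly initialized" warning is expected).
model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-neo-125M", num_labels=2)
print(type(model).__name__)  # GPTNeoForSequenceClassification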


@@ -28,6 +28,7 @@ if is_torch_available():
_import_structure["modeling_gpt_neo"] = [
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
"GPTNeoForSequenceClassification",
"GPTNeoModel",
"GPTNeoPreTrainedModel",
"load_tf_weights_in_gpt_neo",
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
from .modeling_gpt_neo import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,


@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -31,6 +31,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@@ -1027,3 +1028,120 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
    """
    The GPTNeo Model transformer with a sequence classification head on top (linear layer).

    :class:`~transformers.GPTNeoForSequenceClassification` uses the last token in order to do the classification, as
    other causal models (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
    the last value in each row of the batch).
    """,
    GPT_NEO_START_DOCSTRING,
)
class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTNeoModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        self.init_weights()

    @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[range(batch_size), sequence_lengths]

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
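
A minimal usage sketch for the class above, assuming the EleutherAI/gpt-neo-125M checkpoint and its GPT-2 style tokenizer purely as an example; the EOS token stands in for a pad token so the model can locate the last non-padding position in each row:

import torch
from transformers import GPT2Tokenizer, GPTNeoForSequenceClassification

model_name = "EleutherAI/gpt-neo-125M"  # example checkpoint, assumed for illustration
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-Neo defines no pad token by default

model = GPTNeoForSequenceClassification.from_pretrained(model_name, num_labels=2)
model.config.pad_token_id = tokenizer.pad_token_id  # required for batch sizes > 1 (see the assert above)

inputs = tokenizer(["a great movie", "a terrible movie"], padding=True, return_tensors="pt")
labels = torch.tensor([1, 0])
outputs = model(**inputs, labels=labels)

# num_labels > 1, so a cross-entropy loss is computed; logits are pooled to (batch_size, num_labels)
print(outputs.loss.item(), outputs.logits.shape)  # scalar loss, torch.Size([2, 2])

With num_labels == 1 the same forward pass computes an MSE regression loss instead, matching the behavior of GPT2ForSequenceClassification.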


@@ -1603,6 +1603,15 @@ class GPTNeoForCausalLM:
        requires_backends(self, ["torch"])


class GPTNeoForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPTNeoModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


@@ -361,7 +361,6 @@ class GPT2ModelTester:
        model = GPT2ForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
-        print(config.num_labels, sequence_labels.size())
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))


@@ -34,6 +34,7 @@ if is_torch_available():
GPT2Tokenizer,
GPTNeoConfig,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
)
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
@@ -238,6 +239,16 @@ class GPTNeoModelTester:
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_gpt_neo_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTNeoForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPTNeoForCausalLM(config)
        model.to(torch_device)
@@ -274,7 +285,9 @@ class GPTNeoModelTester:
@require_torch
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
-    all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
+    all_model_classes = (
+        (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
+    )
    all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
    test_missing_keys = False
@@ -305,6 +318,10 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

    def test_gpt_neo_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)

    def test_gpt_neo_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)