mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add TFGPT2ForSequenceClassification based on DialogRPT (#8714)
* Add TFGPT2ForSequenceClassification based on DialogRPT * Add TFGPT2ForSequenceClassification based on DialogRPT * TFGPT2ForSequenceClassification based on DialogRPT-refactored code, implemented review comments and added input processing * Add TFGPT2ForSequenceClassification based on DialogRPT * TFGPT2ForSequenceClassification based on DialogRPT-refactored code, implemented review comments and added input processing * code refactor for latest other TF PR * code refactor * code refactor * Update modeling_tf_gpt2.py
This commit is contained in:
parent
28c77ddf3b
commit
483e13273f
@ -114,3 +114,15 @@ TFGPT2DoubleHeadsModel
|
||||
|
||||
.. autoclass:: transformers.TFGPT2DoubleHeadsModel
|
||||
:members: call
|
||||
|
||||
TFGPT2ForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFGPT2ForSequenceClassification
|
||||
:members: call
|
||||
|
||||
TFSequenceClassifierOutputWithPast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast
|
||||
:members:
|
||||
|
@ -775,6 +775,7 @@ if is_tf_available():
|
||||
from .models.gpt2 import (
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFGPT2DoubleHeadsModel,
|
||||
TFGPT2ForSequenceClassification,
|
||||
TFGPT2LMHeadModel,
|
||||
TFGPT2MainLayer,
|
||||
TFGPT2Model,
|
||||
|
@ -557,3 +557,39 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
encoder_last_hidden_state: Optional[tf.Tensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TFSequenceClassifierOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of sentence classification models.
|
||||
|
||||
Args:
|
||||
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
|
||||
num_heads, sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||
``past_key_values`` input) to speed up sequential decoding.
|
||||
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[tf.Tensor] = None
|
||||
logits: tf.Tensor = None
|
||||
past_key_values: Optional[List[tf.Tensor]] = None
|
||||
hidden_states: Optional[Tuple[tf.Tensor]] = None
|
||||
attentions: Optional[Tuple[tf.Tensor]] = None
|
||||
|
@ -89,7 +89,7 @@ from ..funnel.modeling_tf_funnel import (
|
||||
TFFunnelForTokenClassification,
|
||||
TFFunnelModel,
|
||||
)
|
||||
from ..gpt2.modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
|
||||
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
|
||||
from ..longformer.modeling_tf_longformer import (
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
|
||||
@ -326,6 +326,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(XLMConfig, TFXLMForSequenceClassification),
|
||||
(ElectraConfig, TFElectraForSequenceClassification),
|
||||
(FunnelConfig, TFFunnelForSequenceClassification),
|
||||
(GPT2Config, TFGPT2ForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -25,6 +25,7 @@ if is_tf_available():
|
||||
from .modeling_tf_gpt2 import (
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFGPT2DoubleHeadsModel,
|
||||
TFGPT2ForSequenceClassification,
|
||||
TFGPT2LMHeadModel,
|
||||
TFGPT2MainLayer,
|
||||
TFGPT2Model,
|
||||
|
@ -28,11 +28,16 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutputWithPast,
|
||||
TFCausalLMOutputWithPast,
|
||||
TFSequenceClassifierOutputWithPast,
|
||||
)
|
||||
from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFConv1D,
|
||||
TFPreTrainedModel,
|
||||
TFSequenceClassificationLoss,
|
||||
TFSequenceSummary,
|
||||
TFSharedEmbeddings,
|
||||
get_initializer,
|
||||
@ -853,3 +858,158 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The GPT2 Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
:class:`~transformers.TFGPT2ForSequenceClassification` uses the last token in order to do the classification, as
|
||||
other causal models (e.g. GPT-1) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
|
||||
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
|
||||
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
|
||||
the last value in each row of the batch).
|
||||
""",
|
||||
GPT2_START_DOCSTRING,
|
||||
)
|
||||
class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss):
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.num_labels = config.num_labels
|
||||
self.score = tf.keras.layers.Dense(
|
||||
config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="score",
|
||||
use_bias=False,
|
||||
)
|
||||
self.transformer = TFGPT2MainLayer(config, name="transformer")
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.transformer.wte
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="microsoft/DialogRPT-updown",
|
||||
output_type=TFSequenceClassifierOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids=inputs["input_ids"],
|
||||
past=inputs["past"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
logits_shape = shape_list(logits)
|
||||
in_logits = None
|
||||
if self.config.pad_token_id is None:
|
||||
sequence_lengths = -1
|
||||
else:
|
||||
if inputs["input_ids"] is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def get_seq_element(sequence_position, input_batch):
|
||||
return tf.strided_slice(
|
||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
||||
)
|
||||
|
||||
result = tf.map_fn(
|
||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
||||
)
|
||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
loss = None
|
||||
|
||||
if inputs["labels"] is not None:
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0:batch_size, sequence_lengths]
|
||||
|
||||
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
output = (pooled_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFSequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
@ -768,6 +768,15 @@ class TFGPT2DoubleHeadsModel:
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFGPT2ForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFGPT2LMHeadModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
@ -29,6 +29,7 @@ if is_tf_available():
|
||||
from transformers.models.gpt2.modeling_tf_gpt2 import (
|
||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFGPT2DoubleHeadsModel,
|
||||
TFGPT2ForSequenceClassification,
|
||||
TFGPT2LMHeadModel,
|
||||
TFGPT2Model,
|
||||
shape_list,
|
||||
@ -65,6 +66,7 @@ class TFGPT2ModelTester:
|
||||
self.scope = None
|
||||
self.bos_token_id = self.vocab_size - 1
|
||||
self.eos_token_id = self.vocab_size - 1
|
||||
self.pad_token_id = self.vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -104,6 +106,8 @@ class TFGPT2ModelTester:
|
||||
# initializer_range=self.initializer_range
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
@ -271,6 +275,21 @@ class TFGPT2ModelTester:
|
||||
)
|
||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_gpt2_for_sequence_classification(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
"labels": sequence_labels,
|
||||
}
|
||||
model = TFGPT2ForSequenceClassification(config)
|
||||
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
@ -297,7 +316,11 @@ class TFGPT2ModelTester:
|
||||
@require_tf
|
||||
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||
all_model_classes = (
|
||||
(TFGPT2Model, TFGPT2LMHeadModel, TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
@ -331,6 +354,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
|
||||
|
||||
def test_gpt2_sequence_classification_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
Loading…
Reference in New Issue
Block a user