Add TFGPT2ForSequenceClassification based on DialogRPT (#8714)

* Add TFGPT2ForSequenceClassification based on DialogRPT

* Add TFGPT2ForSequenceClassification based on DialogRPT

* TFGPT2ForSequenceClassification based on DialogRPT: refactored code, addressed review comments, and added input processing

* Add TFGPT2ForSequenceClassification based on DialogRPT

* TFGPT2ForSequenceClassification based on DialogRPT: refactored code, addressed review comments, and added input processing

* code refactor to align with other recent TF PRs

* code refactor

* code refactor

* Update modeling_tf_gpt2.py
Authored by sandip on 2020-12-07 21:28:37 +05:30, committed by GitHub
commit 483e13273f (parent 28c77ddf3b)
8 changed files with 250 additions and 3 deletions

docs/source/model_doc/gpt2.rst

@@ -114,3 +114,15 @@ TFGPT2DoubleHeadsModel
.. autoclass:: transformers.TFGPT2DoubleHeadsModel
:members: call
TFGPT2ForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFGPT2ForSequenceClassification
:members: call
TFSequenceClassifierOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast
:members:

src/transformers/__init__.py

@@ -775,6 +775,7 @@ if is_tf_available():
from .models.gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFGPT2DoubleHeadsModel,
TFGPT2ForSequenceClassification,
TFGPT2LMHeadModel,
TFGPT2MainLayer,
TFGPT2Model,

src/transformers/modeling_tf_outputs.py

@@ -557,3 +557,39 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFSequenceClassifierOutputWithPast(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
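Not part of the diff: a minimal sketch of the new output class. Like other ModelOutput subclasses it supports attribute and dict-style access, and past_key_values is exposed so cached decoding state can be carried between calls.

import tensorflow as tf
from transformers.modeling_tf_outputs import TFSequenceClassifierOutputWithPast

# Construct an output by hand; in practice the model returns it from call().
out = TFSequenceClassifierOutputWithPast(logits=tf.constant([[0.1, -0.2]]))
print(out.logits.shape)      # (1, 2)
print(out.past_key_values)   # None unless use_cache=True was passed to the model
print(out["logits"])         # dict-style access works as well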

src/transformers/models/auto/modeling_tf_auto.py

@@ -89,7 +89,7 @@ from ..funnel.modeling_tf_funnel import (
TFFunnelForTokenClassification,
TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
@@ -326,6 +326,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(XLMConfig, TFXLMForSequenceClassification),
(ElectraConfig, TFElectraForSequenceClassification),
(FunnelConfig, TFFunnelForSequenceClassification),
(GPT2Config, TFGPT2ForSequenceClassification),
]
)
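Not part of the diff: with GPT2Config registered in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, the TF auto class should resolve GPT-2 style checkpoints to the new head. A hedged sketch; the checkpoint name comes from the code-sample docstring further down, and from_pt=True may be needed if the hub repository only publishes PyTorch weights.

from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("microsoft/DialogRPT-updown")
print(type(model).__name__)  # expected: TFGPT2ForSequenceClassification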

src/transformers/models/gpt2/__init__.py

@@ -25,6 +25,7 @@ if is_tf_available():
from .modeling_tf_gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFGPT2DoubleHeadsModel,
TFGPT2ForSequenceClassification,
TFGPT2LMHeadModel,
TFGPT2MainLayer,
TFGPT2Model,

src/transformers/models/gpt2/modeling_tf_gpt2.py

@@ -28,11 +28,16 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPast,
TFCausalLMOutputWithPast,
TFSequenceClassifierOutputWithPast,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFConv1D,
TFPreTrainedModel,
TFSequenceClassificationLoss,
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
@@ -853,3 +858,158 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
"""
The GPT2 Model transformer with a sequence classification head on top (linear layer).
:class:`~transformers.TFGPT2ForSequenceClassification` uses the last token in order to do the classification, as
other causal models (e.g. GPT-1) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
the last value in each row of the batch).
""",
GPT2_START_DOCSTRING,
)
class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.score = tf.keras.layers.Dense(
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="score",
use_bias=False,
)
self.transformer = TFGPT2MainLayer(config, name="transformer")
def get_output_embeddings(self):
return self.transformer.wte
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="microsoft/DialogRPT-updown",
output_type=TFSequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
past=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
config.num_labels - 1]``.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
past=inputs["past"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
logits_shape = shape_list(logits)
in_logits = None
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if inputs["input_ids"] is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
-1,
keepdims=False,
)
- 1
)
def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)
result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds`."
)
loss = None
if inputs["labels"] is not None:
if inputs["input_ids"] is not None:
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
else:
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits
if not inputs["return_dict"]:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
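Not part of the diff: a minimal usage sketch of the new head. As the class docstring above explains, every position is scored but only the logits at the last non-padding token are kept, so pad_token_id should be set when feeding padded batches. The checkpoint is the one referenced by add_code_sample_docstrings; loading it in TF may require from_pt=True if only PyTorch weights are published.

from transformers import GPT2Tokenizer, TFGPT2ForSequenceClassification

tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialogRPT-updown")
model = TFGPT2ForSequenceClassification.from_pretrained("microsoft/DialogRPT-updown")

# A single un-padded sequence: the logits of its final token are returned.
inputs = tokenizer("Does money buy happiness?", return_tensors="tf")
outputs = model(inputs)
print(outputs.logits.shape)  # (batch_size, config.num_labels)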

src/transformers/utils/dummy_tf_objects.py

@@ -768,6 +768,15 @@ class TFGPT2DoubleHeadsModel:
requires_tf(self)
class TFGPT2ForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFGPT2LMHeadModel:
def __init__(self, *args, **kwargs):
requires_tf(self)

tests/test_modeling_tf_gpt2.py

@@ -29,6 +29,7 @@ if is_tf_available():
from transformers.models.gpt2.modeling_tf_gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFGPT2DoubleHeadsModel,
TFGPT2ForSequenceClassification,
TFGPT2LMHeadModel,
TFGPT2Model,
shape_list,
@@ -65,6 +66,7 @@ class TFGPT2ModelTester:
self.scope = None
self.bos_token_id = self.vocab_size - 1
self.eos_token_id = self.vocab_size - 1
self.pad_token_id = self.vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -104,6 +106,8 @@ class TFGPT2ModelTester:
# initializer_range=self.initializer_range
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
return_dict=True,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -271,6 +275,21 @@ class TFGPT2ModelTester:
)
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
def create_and_check_gpt2_for_sequence_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
):
config.num_labels = self.num_labels
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"labels": sequence_labels,
}
model = TFGPT2ForSequenceClassification(config)
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -297,7 +316,11 @@ class TFGPT2ModelTester:
@require_tf
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
all_model_classes = (
(TFGPT2Model, TFGPT2LMHeadModel, TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel)
if is_tf_available()
else ()
)
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
def setUp(self):
@@ -331,6 +354,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
def test_gpt2_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: