Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Added TF OpenAi GPT1 Sequence Classification (#9105)
* TF OpenAI GPT Sequence Classification

* Update src/transformers/models/openai/modeling_tf_openai.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent ef2d4cd445
commit 389aba34bf
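For orientation, here is a minimal usage sketch of the head this commit adds. It is illustrative only: the `num_labels=2` task size and the example sentence are assumptions, and the classification layer is randomly initialized on top of the pretrained openai-gpt weights.

```python
import tensorflow as tf
from transformers import OpenAIGPTTokenizer, TFOpenAIGPTForSequenceClassification

tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
model = TFOpenAIGPTForSequenceClassification.from_pretrained("openai-gpt", num_labels=2)

# openai-gpt defines no pad token, so this sketch classifies a single sequence;
# for batches > 1 a pad_token_id must be set in the config (see the class docstring below).
inputs = tokenizer("a quick smoke test of the new classification head", return_tensors="tf")
outputs = model(inputs, labels=tf.constant([1]))

print(outputs.logits.shape)  # (1, 2)
print(float(outputs.loss))   # cross-entropy loss for the single labeled example
```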
@@ -138,3 +138,9 @@ TFOpenAIGPTDoubleHeadsModel
.. autoclass:: transformers.TFOpenAIGPTDoubleHeadsModel
    :members: call


TFOpenAIGPTForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFOpenAIGPTForSequenceClassification
    :members: call
@@ -859,6 +859,7 @@ if is_tf_available():
    from .models.openai import (
        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFOpenAIGPTDoubleHeadsModel,
        TFOpenAIGPTForSequenceClassification,
        TFOpenAIGPTLMHeadModel,
        TFOpenAIGPTMainLayer,
        TFOpenAIGPTModel,
@@ -120,7 +120,7 @@ from ..mpnet.modeling_tf_mpnet import (
    TFMPNetModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
from ..roberta.modeling_tf_roberta import (
    TFRobertaForMaskedLM,
@@ -341,6 +341,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
        (FunnelConfig, TFFunnelForSequenceClassification),
        (GPT2Config, TFGPT2ForSequenceClassification),
        (MPNetConfig, TFMPNetForSequenceClassification),
        (OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
    ]
)
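The mapping entry above is what lets the TF auto class dispatch to the new head. A small sketch of the effect (the `num_labels=3` value is illustrative):

```python
from transformers import TFAutoModelForSequenceClassification

# With (OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification) registered in
# TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, the auto class resolves openai-gpt to the new head.
model = TFAutoModelForSequenceClassification.from_pretrained("openai-gpt", num_labels=3)
print(type(model).__name__)  # TFOpenAIGPTForSequenceClassification
```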
@@ -39,6 +39,7 @@ if is_tf_available():
    from .modeling_tf_openai import (
        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFOpenAIGPTDoubleHeadsModel,
        TFOpenAIGPTForSequenceClassification,
        TFOpenAIGPTLMHeadModel,
        TFOpenAIGPTMainLayer,
        TFOpenAIGPTModel,
@@ -28,11 +28,12 @@ from ...file_utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFConv1D,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFSharedEmbeddings,
    get_initializer,
@@ -762,3 +763,154 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The OpenAI GPT Model transformer with a sequence classification head on top (linear layer).

    :class:`~transformers.TFOpenAIGPTForSequenceClassification` uses the last token in order to do the classification,
    as other causal models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
    row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
    guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
    the last value in each row of the batch).
    """,
    OPENAI_GPT_START_DOCSTRING,
)
class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.score = tf.keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="score",
            use_bias=False,
        )
        self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")

    def get_output_embeddings(self):
        return self.transformer.tokens_embed

    @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="openai-gpt",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ):
r"""
|
||||
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
"""
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )

        transformer_outputs = self.transformer(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)
        logits_shape = shape_list(logits)
        in_logits = None
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if inputs["input_ids"] is not None:
                sequence_lengths = (
                    tf.reduce_sum(
                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
                        -1,
                        keepdims=False,
                    )
                    - 1
                )

                def get_seq_element(sequence_position, input_batch):
                    return tf.strided_slice(
                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
                    )

                result = tf.map_fn(
                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
                )
                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )
        loss = None

        if inputs["labels"] is not None:
            if input_ids is not None:
                batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
            else:
                batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
            assert (
                self.config.pad_token_id is not None or batch_size == 1
            ), "Cannot handle batch sizes > 1 if no padding token is defined."

            if not tf.is_tensor(sequence_lengths):
                in_logits = logits[0:batch_size, sequence_lengths]

            loss = self.compute_loss(
                tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
            )

        pooled_logits = in_logits if in_logits is not None else logits

        if not inputs["return_dict"]:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
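The least obvious part of the new `call` is selecting, for each row, the logits at the last non-padding position. Below is a standalone sketch of that selection with made-up tensors; it uses `tf.gather(..., batch_dims=1)` rather than the `tf.map_fn`/`tf.strided_slice` pair in the commit, which should give the same result for right-padded inputs.

```python
import tensorflow as tf

pad_token_id = 0  # assumed pad id for this sketch
input_ids = tf.constant([[5, 8, 3, 0, 0],   # 3 real tokens
                         [7, 2, 9, 4, 0]])  # 4 real tokens
logits = tf.random.normal((2, 5, 3))        # (batch, seq_len, num_labels) stand-in

# As in the diff: count non-pad tokens per row and subtract one to get the index of the
# last real token (assumes right padding and that the pad id does not occur mid-sequence).
sequence_lengths = tf.reduce_sum(
    tf.cast(tf.math.not_equal(input_ids, pad_token_id), tf.int32), axis=-1
) - 1  # -> [2, 3]

# Gather each row's logits at that position (same effect as the map_fn/strided_slice version).
pooled_logits = tf.gather(logits, sequence_lengths, batch_dims=1)
print(pooled_logits.shape)  # (2, 3)
```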
@@ -1116,6 +1116,15 @@ class TFOpenAIGPTDoubleHeadsModel:
        requires_tf(self)


class TFOpenAIGPTForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)


class TFOpenAIGPTLMHeadModel:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
@@ -29,6 +29,7 @@ if is_tf_available():
    from transformers.models.openai.modeling_tf_openai import (
        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFOpenAIGPTDoubleHeadsModel,
        TFOpenAIGPTForSequenceClassification,
        TFOpenAIGPTLMHeadModel,
        TFOpenAIGPTModel,
    )
@@ -62,6 +63,7 @@ class TFOpenAIGPTModelTester:
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None
        self.pad_token_id = self.vocab_size - 1

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -99,6 +101,7 @@ class TFOpenAIGPTModelTester:
            n_ctx=self.max_position_embeddings,
            # type_vocab_size=self.type_vocab_size,
            # initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -154,6 +157,21 @@ class TFOpenAIGPTModelTester:
        )
        self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_openai_gpt_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        config.num_labels = self.num_labels
        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
            "labels": sequence_labels,
        }
        model = TFOpenAIGPTForSequenceClassification(config)
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
@@ -177,7 +195,9 @@ class TFOpenAIGPTModelTester:
class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
        (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel, TFOpenAIGPTForSequenceClassification)
        if is_tf_available()
        else ()
    )
    all_generative_model_classes = (
        (TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
@@ -213,6 +233,10 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
        name = model.get_prefix_bias_name()
        assert name is None

    def test_openai_gpt_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: