mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-07 14:50:07 +06:00
830 lines
35 KiB
Python
830 lines
35 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" TF 2.0 XXX model. """
|
|
|
|
####################################################
|
|
# In this template, replace all the XXX (various casings) with your model name
|
|
####################################################
|
|
|
|
|
|
import logging
|
|
|
|
import tensorflow as tf
|
|
|
|
from .configuration_xxx import XxxConfig
|
|
from .file_utils import (
|
|
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
|
add_code_sample_docstrings,
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_callable,
|
|
)
|
|
from .modeling_tf_outputs import (
|
|
TFBaseModelOutputWithPooling,
|
|
TFMaskedLMOutput,
|
|
TFMultipleChoiceModelOutput,
|
|
TFQuestionAnsweringModelOutput,
|
|
TFSequenceClassifierOutput,
|
|
TFTokenClassifierOutput,
|
|
)
|
|
from .modeling_tf_utils import (
|
|
TFMaskedLanguageModelingLoss,
|
|
TFMultipleChoiceLoss,
|
|
TFPreTrainedModel,
|
|
TFQuestionAnsweringLoss,
|
|
TFSequenceClassificationLoss,
|
|
TFTokenClassificationLoss,
|
|
get_initializer,
|
|
shape_list,
|
|
)
|
|
from .tokenization_utils import BatchEncoding
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Class names handed to ``add_code_sample_docstrings`` for the generated usage examples.
# The config name has to match the ``XxxConfig`` class imported by this module.
_CONFIG_FOR_DOC = "XxxConfig"
_TOKENIZER_FOR_DOC = "XxxTokenizer"
|
|
|
|
####################################################
|
|
# This list contains shortcut names for some of
|
|
# the pretrained weights provided with the models
|
|
####################################################
|
|
# Shortcut names of pretrained checkpoints available for this model.
TF_XXX_PRETRAINED_MODEL_ARCHIVE_LIST = ["xxx-base-uncased", "xxx-large-uncased"]
|
|
|
|
|
|
####################################################
|
|
# TF 2.0 Models are constructed using Keras imperative API by sub-classing
|
|
# - tf.keras.layers.Layer for the layers and
|
|
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
|
|
####################################################
|
|
|
|
####################################################
|
|
# Here is an example of typical layer in a TF 2.0 model of the library
|
|
# The classes are usually identical to the PyTorch ones and prefixed with 'TF'.
|
|
#
|
|
# Note that the class __init__ parameters include **kwargs (sent to 'super').
|
|
# This lets us control the class scope and variable names:
|
|
# More precisely, we set the names of the class attributes (lower level layers) to
|
|
# the equivalent attribute names in the PyTorch model so we can have equivalent
|
|
# class and scope structure between PyTorch and TF 2.0 models and easily load one in the other.
|
|
#
|
|
# See the conversion methods in modeling_tf_pytorch_utils.py for more details
|
|
####################################################
|
|
|
|
# Placeholder: aliases the base Keras layer; swap in the model's attention layer when filling in this template.
TFXxxAttention = tf.keras.layers.Layer
|
|
|
|
# Placeholder: aliases the base Keras layer; swap in the model's intermediate layer when filling in this template.
TFXxxIntermediate = tf.keras.layers.Layer
|
|
|
|
# Placeholder: aliases the base Keras layer; swap in the model's output layer when filling in this template.
TFXxxOutput = tf.keras.layers.Layer
|
|
|
|
|
|
class TFXxxLayer(tf.keras.layers.Layer):
    """A single transformer block chaining the attention, intermediate and output sub-layers."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Explicit names keep the variable scopes aligned with the PyTorch attribute names.
        self.attention = TFXxxAttention(config, name="attention")
        self.intermediate = TFXxxIntermediate(config, name="intermediate")
        self.transformer_output = TFXxxOutput(config, name="output")

    def call(self, inputs, training=False):
        """Run the block.

        Args:
            inputs: a 3-item sequence ``(hidden_states, attention_mask, head_mask)``.
            training: forwarded to the attention and output sub-layers.

        Returns:
            A tuple whose first item is the block output, followed by whatever extra
            items the attention sub-layer returned after its first one.
        """
        hidden_states, attention_mask, head_mask = inputs

        attn_results = self.attention([hidden_states, attention_mask, head_mask], training=training)
        attended = attn_results[0]
        expanded = self.intermediate(attended)
        block_output = self.transformer_output([expanded, attended], training=training)

        # Anything past the first attention result (e.g. attentions) is passed through untouched.
        return (block_output,) + attn_results[1:]
|
|
|
|
|
|
####################################################
|
|
# The full model without a specific pretrained or finetuning head is
|
|
# provided as a tf.keras.layers.Layer usually called "TFXxxMainLayer"
|
|
####################################################
|
|
class TFXxxMainLayer(tf.keras.layers.Layer):
    """Headless XXX model: embeddings, encoder and pooler, without any task-specific head.

    NOTE(review): ``__init__`` below only forwards ``kwargs`` to the Keras base class. The attributes
    read by the other methods (``embeddings``, ``encoder``, ``pooler``, ``num_hidden_layers``,
    ``output_attentions``, ``output_hidden_states``, ``return_dict``) are not assigned anywhere in
    this class and have to be set up when the template is filled in.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

    def get_input_embeddings(self):
        """Return the embeddings sub-layer."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Install ``value`` as the word embeddings and record its first dimension as the vocabulary size."""
        embeddings = self.embeddings
        embeddings.word_embeddings = value
        embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        # Head pruning is not available yet for the TF 2.0 models of the library.
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Encode the inputs and pool the resulting sequence.

        ``inputs`` may be a single tensor of input ids, a list/tuple of up to 9 positional inputs
        (in the order of this signature, starting at ``input_ids``), or a dict / ``BatchEncoding``
        keyed by argument name. Values found in ``inputs`` take precedence over the keyword arguments.

        Returns:
            A ``TFBaseModelOutputWithPooling`` when ``return_dict`` is truthy, otherwise the tuple
            ``(sequence_output, pooled_output) + encoder_outputs[1:]``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are provided.
            NotImplementedError: if a ``head_mask`` is provided.
        """
        # ---- Unpack the three supported calling conventions --------------------------------
        if isinstance(inputs, (tuple, list)):
            n_inputs = len(inputs)
            input_ids = inputs[0]
            if n_inputs > 1:
                attention_mask = inputs[1]
            if n_inputs > 2:
                token_type_ids = inputs[2]
            if n_inputs > 3:
                position_ids = inputs[3]
            if n_inputs > 4:
                head_mask = inputs[4]
            if n_inputs > 5:
                inputs_embeds = inputs[5]
            if n_inputs > 6:
                output_attentions = inputs[6]
            if n_inputs > 7:
                output_hidden_states = inputs[7]
            if n_inputs > 8:
                return_dict = inputs[8]
            assert n_inputs <= 9, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        # ---- Fall back to the layer-level defaults for the output switches -----------------
        if output_attentions is None:
            output_attentions = self.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.output_hidden_states
        if return_dict is None:
            return_dict = self.return_dict

        # ---- Exactly one of input_ids / inputs_embeds determines the input shape -----------
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            # Drop the trailing embedding dimension.
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Default to attending everywhere and to a single segment.
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        # Insert two singleton axes so the [batch_size, to_seq_length] mask has sizes
        # [batch_size, 1, 1, to_seq_length] and broadcasts against
        # [batch_size, num_heads, from_seq_length, to_seq_length]. No causal (triangular)
        # masking is needed here, only the broadcast dimensions.
        broadcast_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)

        # Turn the keep(1.0)/mask(0.0) flags into an additive bias: 0.0 where we attend and
        # -10000.0 where we do not. Added to the raw scores before the softmax, this
        # effectively removes the masked positions.
        extended_attention_mask = (1.0 - broadcast_mask) * -10000.0

        # Head masking is not supported: any user-provided mask is rejected, and each layer
        # receives ``None`` instead.
        if head_mask is not None:
            raise NotImplementedError
        head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if return_dict:
            return TFBaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (sequence_output, pooled_output) + encoder_outputs[1:]
|
|
|
|
|
|
####################################################
|
|
# TFXxxPreTrainedModel is a sub-class of tf.keras.Model
|
|
# which takes care of loading and saving pretrained weights
|
|
# and various common utilities.
|
|
# Here you just need to specify a few (self-explanatory)
|
|
# pointers for your model.
|
|
####################################################
|
|
class TFXxxPreTrainedModel(TFPreTrainedModel):
    """Abstract base class for the XXX models.

    Handles weights initialization and exposes a simple interface for
    downloading and loading pretrained models.
    """

    # Configuration class paired with this model family.
    config_class = XxxConfig
    # Name of the attribute under which the model classes below store their main layer.
    base_model_prefix = "transformer"
|
|
|
|
|
|
# Shared class-level documentation, attached to every model class via ``add_start_docstrings``.
XXX_START_DOCSTRING = r"""
    The XXX model was proposed in
    `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding
    <https://arxiv.org/abs/1810.04805>`__ by....

    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. note::

        TF 2.0 models accepts two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional arguments.

        This second option is useful when using :obj:`tf.keras.Model.fit()` method which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument :

        - a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
|
|
|
|
# Shared argument documentation for the ``call`` methods. ``{0}`` is a placeholder for the
# input shape and is meant to be filled in with ``str.format`` before the text is attached.
XXX_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.XxxTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
        output_attentions (:obj:`bool`, `optional`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""
|
|
|
|
|
|
# Bare model: a thin Keras-model wrapper around ``TFXxxMainLayer`` with no task head.
@add_start_docstrings(
    "The bare XXX Model transformer outputing raw hidden-states without any specific head on top.",
    XXX_START_DOCSTRING,
)
class TFXxxModel(TFXxxPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFXxxMainLayer(config, name="transformer")

    # Every argument is handed to the main layer unchanged and its result is returned as-is.
    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-cased",
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(self, inputs, **kwargs):
        return self.transformer(inputs, **kwargs)
|
|
|
|
|
|
# Placeholder: aliases the base Keras layer; swap in the model's masked-LM head when filling in this template.
TFXxxMLMHead = tf.keras.layers.Layer
|
|
|
|
|
|
# Masked-LM model: main layer followed by an MLM head built on the main layer's embeddings.
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING)
class TFXxxForMaskedLM(TFXxxPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name="mlm")

    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-cased",
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        """
        if return_dict is None:
            return_dict = self.transformer.return_dict

        # Labels may travel inside ``inputs``: as the 10th positional item, or under the "labels" key.
        # They are split off here so that the main layer only receives the inputs it knows about.
        if isinstance(inputs, (tuple, list)):
            if len(inputs) > 9:
                labels = inputs[9]
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` removes the key from the mapping passed in by the caller.
            labels = inputs.pop("labels", labels)

        outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        prediction_scores = self.mlm(outputs[0], training=training)

        loss = None
        if labels is not None:
            loss = self.compute_loss(labels, prediction_scores)

        if return_dict:
            return TFMaskedLMOutput(
                loss=loss,
                logits=prediction_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: skip the first two main-layer results and lead with the loss when there is one.
        tuple_output = (prediction_scores,) + outputs[2:]
        if loss is not None:
            return (loss,) + tuple_output
        return tuple_output
|
|
|
|
|
|
# Sequence classification / regression model: main layer, dropout on the pooled output,
# then a dense classifier with ``config.num_labels`` outputs.
@add_start_docstrings(
    """XXX Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XXX_START_DOCSTRING,
)
class TFXxxForSequenceClassification(TFXxxPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    # The shared inputs docstring carries a ``{0}`` shape placeholder; it has to be formatted
    # before being attached, otherwise the rendered documentation shows a literal ``{0}``.
    #
    # Returns a ``TFSequenceClassifierOutput`` when ``return_dict`` is truthy, otherwise the tuple
    # ``(logits,) + outputs[2:]`` with the loss prepended when labels were provided.
    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-cased",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.transformer.return_dict

        # Labels may travel inside ``inputs`` (10th positional item or "labels" key); split them
        # off so that the main layer only receives the inputs it knows about.
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` removes the key from the mapping passed in by the caller.
            labels = inputs.pop("labels", labels)

        outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # The second main-layer result is the pooled output.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)

        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
# Multiple-choice model: every choice is encoded as its own sequence, scored with a
# single-unit dense layer, and the scores are regrouped per example.
@add_start_docstrings(
    """XXX Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    XXX_START_DOCSTRING,
)
class TFXxxForMultipleChoice(TFXxxPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @property
    def dummy_inputs(self):
        """Dummy inputs to build the network.

        Returns:
            A dict with a single ``"input_ids"`` entry holding a constant tensor of dummy inputs.
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

    # Returns a ``TFMultipleChoiceModelOutput`` when ``return_dict`` is truthy, otherwise the tuple
    # ``(reshaped_logits,) + outputs[2:]`` with the loss prepended when labels were provided.
    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-cased",
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        """
        # Unpack the three supported calling conventions; values found in ``inputs`` win over keywords.
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            labels = inputs[9] if len(inputs) > 9 else labels
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs
        return_dict = return_dict if return_dict is not None else self.transformer.return_dict

        # Dimensions 1 and 2 of the (ids or embeddings) input give the number of choices and the sequence length.
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        # Fold the choices dimension into the leading dimension so each choice is encoded independently.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )

        # Positional order expected by the main layer's list-style input.
        flat_inputs = [
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict,
        ]

        outputs = self.transformer(flat_inputs, training=training)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        # Regroup the one-score-per-choice logits into rows of ``num_choices`` scores.
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        loss = None if labels is None else self.compute_loss(labels, reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
# Token classification model: main layer, dropout on the per-token hidden states,
# then a dense classifier with ``config.num_labels`` outputs.
@add_start_docstrings(
    """XXX Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XXX_START_DOCSTRING,
)
class TFXxxForTokenClassification(TFXxxPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    # The shared inputs docstring carries a ``{0}`` shape placeholder; it has to be formatted
    # before being attached, otherwise the rendered documentation shows a literal ``{0}``.
    #
    # Returns a ``TFTokenClassifierOutput`` when ``return_dict`` is truthy, otherwise the tuple
    # ``(logits,) + outputs[2:]`` with the loss prepended when labels were provided.
    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-cased",
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.transformer.return_dict

        # Labels may travel inside ``inputs`` (10th positional item or "labels" key); split them
        # off so that the main layer only receives the inputs it knows about.
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` removes the key from the mapping passed in by the caller.
            labels = inputs.pop("labels", labels)

        outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # The first main-layer result is the per-token sequence output.
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
# Extractive question-answering model: main layer followed by a dense layer whose output
# is split in two along the last axis to give the span-start and span-end logits.
@add_start_docstrings(
    """XXX Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XXX_START_DOCSTRING,
)
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name="transformer")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    # The shared inputs docstring carries a ``{0}`` shape placeholder; it has to be formatted
    # before being attached, otherwise the rendered documentation shows a literal ``{0}``.
    #
    # Returns a ``TFQuestionAnsweringModelOutput`` when ``return_dict`` is truthy, otherwise the tuple
    # ``(start_logits, end_logits) + outputs[2:]`` with the loss prepended when both positions were provided.
    @add_start_docstrings_to_callable(XXX_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xxx-base-cased",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.transformer.return_dict

        # The span labels may travel inside ``inputs`` (10th/11th positional items or their
        # named keys); split them off so that the main layer only receives the inputs it knows about.
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[9] if len(inputs) > 9 else start_positions
            end_positions = inputs[10] if len(inputs) > 10 else end_positions
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # NOTE: ``pop`` removes the keys from the mapping passed in by the caller.
            start_positions = inputs.pop("start_positions", start_positions)
            # Fall back to the ``end_positions`` keyword argument (not the start positions)
            # when the mapping has no "end_positions" entry.
            end_positions = inputs.pop("end_positions", end_positions)

        outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Split the last axis in two halves (start / end) and drop the resulting size-1 axis.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        # A loss is only computed when both ends of the span are labelled.
        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|