TensorFlow improvements (#4530)

* Better None gradients handling

* Apply Style

* Apply Style

* Create a loss class per task to compute its respective loss

* Add loss classes to the ALBERT TF models

* Add loss classes to the BERT TF models

* Add question answering and multiple choice to TF Camembert

* Remove prints

* Add multiple choice model to TF DistilBERT + loss computation

* Add question answering model to TF Electra + loss computation

* Add token classification, question answering and multiple choice models to TF Flaubert

* Add multiple choice model to TF Roberta + loss computation

* Add multiple choice model to TF XLM + loss computation

* Add multiple choice and question answering models to TF XLM-Roberta

* Add multiple choice model to TF XLNet + loss computation

* Remove unused parameters

* Add task loss classes

* Reorder TF imports + add new model classes

* Add new model classes

* Bugfix in TF T5 model

* Bugfix for TF T5 tests

* Bugfix in TF T5 model

* Fix TF T5 model tests

* Fix T5 tests + some renaming

* Fix inheritance issue in the AutoX tests

* Add tests for TF Flaubert and TF XLM Roberta

* Add tests for TF Flaubert and TF XLM Roberta

* Remove unused piece of code in the TF trainer

* Bugfix and remove unused code

* Bugfix for TF 2.2

* Apply Style

* Split TFSequenceClassificationAndMultipleChoiceLoss into its two respective classes

* Apply style

* Mirror the PT Trainer in the TF one: fp16, optimizers and tb_writer as class parameters, and better dataset handling

* Fix TF optimization tests and apply style

* Remove useless parameter

* Bugfix and apply style

* Fix TF Trainer prediction

* Now the TF models return the loss like their PyTorch counterparts

* Apply Style

* Ignore some test outputs

* Take into account the SQuAD cls_index, p_mask and is_impossible parameters for the QuestionAnswering task models.

* Fix names for SQuAD data

* Apply Style

* Fix conflicts with 2.11 release

* Fix conflicts with 2.11

* Fix wrong name

* Add better documentation on the new create_optimizer function

* Fix isort

* logging_dir: use same default as PyTorch

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Julien Plu authored on 2020-06-05 01:45:53 +02:00; committed by GitHub
parent ccd26c2862
commit f9414f7553
27 changed files with 2380 additions and 558 deletions
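In short: the TF task models now accept labels and, like their PyTorch counterparts, prepend the computed loss to their outputs. A minimal usage sketch, mirroring the docstring examples updated below:

import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')

input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
labels = tf.reshape(tf.constant(1), (-1, 1))                                 # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]  # the loss is only returned when labels are passed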

.gitignore (4 changes)

@ -8,6 +8,10 @@ __pycache__/
# C extensions
*.so
# tests and logs
tests/fixtures
logs/
# Distribution / packaging
.Python
build/

src/transformers/__init__.py

@ -352,173 +352,193 @@ if is_torch_available():
# TensorFlow
if is_tf_available():
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
TFSequenceSummary,
shape_list,
tf_top_k_top_p_filtering,
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
)
from .modeling_tf_auto import (
TFAutoModel,
TFAutoModelForPreTraining,
TFAutoModelForMultipleChoice,
TFAutoModelForSequenceClassification,
TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead,
TFAutoModelForTokenClassification,
TF_MODEL_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
)
from .modeling_tf_bert import (
TFBertPreTrainedModel,
TFBertMainLayer,
TFBertEmbeddings,
TFBertModel,
TFBertForPreTraining,
TFBertForMaskedLM,
TFBertForNextSentencePrediction,
TFBertForSequenceClassification,
TFBertForMultipleChoice,
TFBertForTokenClassification,
TFBertForQuestionAnswering,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_gpt2 import (
TFGPT2PreTrainedModel,
TFGPT2MainLayer,
TFGPT2Model,
TFGPT2LMHeadModel,
TFGPT2DoubleHeadsModel,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_openai import (
TFOpenAIGPTPreTrainedModel,
TFOpenAIGPTMainLayer,
TFOpenAIGPTModel,
TFOpenAIGPTLMHeadModel,
TFOpenAIGPTDoubleHeadsModel,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_transfo_xl import (
TFTransfoXLPreTrainedModel,
TFTransfoXLMainLayer,
TFTransfoXLModel,
TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
)
from .modeling_tf_xlnet import (
TFXLNetPreTrainedModel,
TFXLNetMainLayer,
TFXLNetModel,
TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_xlm import (
TFXLMPreTrainedModel,
TFXLMMainLayer,
TFXLMModel,
TFXLMWithLMHeadModel,
TFXLMForSequenceClassification,
TFXLMForQuestionAnsweringSimple,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_xlm_roberta import (
TFXLMRobertaForMaskedLM,
TFXLMRobertaModel,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_roberta import (
TFRobertaPreTrainedModel,
TFRobertaMainLayer,
TFRobertaModel,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaForQuestionAnswering,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_camembert import (
TFCamembertModel,
TFCamembertForMaskedLM,
TFCamembertForSequenceClassification,
TFCamembertForTokenClassification,
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_flaubert import (
TFFlaubertModel,
TFFlaubertWithLMHeadModel,
TFFlaubertForSequenceClassification,
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_distilbert import (
TFDistilBertPreTrainedModel,
TFDistilBertMainLayer,
TFDistilBertModel,
TFDistilBertForMaskedLM,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForQuestionAnswering,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
)
from .modeling_tf_ctrl import (
TFCTRLPreTrainedModel,
TFCTRLModel,
TFCTRLLMHeadModel,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForMultipleChoice,
TFAutoModelForPreTraining,
TFAutoModelForQuestionAnswering,
TFAutoModelForSequenceClassification,
TFAutoModelForTokenClassification,
TFAutoModelWithLMHead,
)
from .modeling_tf_albert import (
TFAlbertPreTrainedModel,
TFAlbertMainLayer,
TFAlbertModel,
TFAlbertForPreTraining,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAlbertForMaskedLM,
TFAlbertForMultipleChoice,
TFAlbertForSequenceClassification,
TFAlbertForPreTraining,
TFAlbertForQuestionAnswering,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAlbertForSequenceClassification,
TFAlbertForTokenClassification,
TFAlbertMainLayer,
TFAlbertModel,
TFAlbertPreTrainedModel,
)
from .modeling_tf_t5 import (
TFT5PreTrainedModel,
TFT5Model,
TFT5ForConditionalGeneration,
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
from .modeling_tf_bert import (
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFBertEmbeddings,
TFBertForMaskedLM,
TFBertForMultipleChoice,
TFBertForNextSentencePrediction,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertForTokenClassification,
TFBertMainLayer,
TFBertModel,
TFBertPreTrainedModel,
)
from .modeling_tf_camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM,
TFCamembertModel,
TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering,
TFCamembertForSequenceClassification,
TFCamembertForTokenClassification,
)
from .modeling_tf_ctrl import (
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCTRLLMHeadModel,
TFCTRLModel,
TFCTRLPreTrainedModel,
)
from .modeling_tf_distilbert import (
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFDistilBertForMaskedLM,
TFDistilBertForMultipleChoice,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertMainLayer,
TFDistilBertModel,
TFDistilBertPreTrainedModel,
)
from .modeling_tf_electra import (
TFElectraPreTrainedModel,
TFElectraModel,
TFElectraForPreTraining,
TFElectraForMaskedLM,
TFElectraForTokenClassification,
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForTokenClassification,
TFElectraModel,
TFElectraPreTrainedModel,
)
from .modeling_tf_flaubert import (
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFlaubertForMultipleChoice,
TFFlaubertForQuestionAnsweringSimple,
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertWithLMHeadModel,
TFFlaubertModel,
)
from .modeling_tf_gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFGPT2DoubleHeadsModel,
TFGPT2LMHeadModel,
TFGPT2MainLayer,
TFGPT2Model,
TFGPT2PreTrainedModel,
)
from .modeling_tf_openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
TFOpenAIGPTLMHeadModel,
TFOpenAIGPTMainLayer,
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
from .modeling_tf_roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaMainLayer,
TFRobertaModel,
TFRobertaPreTrainedModel,
)
from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5ForConditionalGeneration,
TFT5Model,
TFT5PreTrainedModel,
)
from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
TFTransfoXLLMHeadModel,
TFTransfoXLMainLayer,
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMWithLMHeadModel,
TFXLMMainLayer,
TFXLMModel,
TFXLMPreTrainedModel,
)
from .modeling_tf_xlm_roberta import (
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMRobertaForMaskedLM,
TFXLMRobertaForMultipleChoice,
TFXLMRobertaForQuestionAnswering,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TFXLMRobertaModel,
)
from .modeling_tf_xlnet import (
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLNetForMultipleChoice,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetLMHeadModel,
TFXLNetMainLayer,
TFXLNetModel,
TFXLNetPreTrainedModel,
)
# Optimization
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
from .optimization_tf import (
AdamWeightDecay,
create_optimizer,
GradientAccumulator,
WarmUp,
)
# Trainer
from .trainer_tf import TFTrainer
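The optimization utilities regrouped above (WarmUp, AdamWeightDecay, create_optimizer, GradientAccumulator) are what the "better documentation on the new create_optimizer function" bullet refers to; their implementation lives in optimization_tf.py and is not part of this diff. As a conceptual illustration only (not the library's WarmUp class), a linear warmup that hands off to a decay schedule can be written with plain tf.keras:

import tensorflow as tf

class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Illustrative warmup schedule: ramp the LR linearly, then delegate to a decay schedule."""

    def __init__(self, peak_learning_rate, decay_schedule, warmup_steps):
        super().__init__()
        self.peak_learning_rate = peak_learning_rate
        self.decay_schedule = decay_schedule
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        return tf.cond(
            step < warmup_steps,
            lambda: self.peak_learning_rate * (step / warmup_steps),
            lambda: self.decay_schedule(step - warmup_steps),
        )

    def get_config(self):
        # decay_schedule omitted for brevity in this sketch
        return {"peak_learning_rate": self.peak_learning_rate, "warmup_steps": self.warmup_steps}

decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=3e-5, decay_steps=1000, end_learning_rate=0.0
)
optimizer = tf.keras.optimizers.Adam(learning_rate=LinearWarmup(3e-5, decay, warmup_steps=100))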

src/transformers/data/processors/squad.py

@ -394,8 +394,8 @@ def squad_convert_examples_to_features(
"qas_id": ex.qas_id,
},
{
"start_position": ex.start_position,
"end_position": ex.end_position,
"start_positions": ex.start_position,
"end_positions": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
"is_impossible": ex.is_impossible,
@ -412,8 +412,8 @@ def squad_convert_examples_to_features(
"qas_id": tf.string,
},
{
"start_position": tf.int64,
"end_position": tf.int64,
"start_positions": tf.int64,
"end_positions": tf.int64,
"cls_index": tf.int64,
"p_mask": tf.int32,
"is_impossible": tf.int32,
@ -429,8 +429,8 @@ def squad_convert_examples_to_features(
"qas_id": tf.TensorShape([]),
},
{
"start_position": tf.TensorShape([]),
"end_position": tf.TensorShape([]),
"start_positions": tf.TensorShape([]),
"end_positions": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([]),

src/transformers/hf_argparser.py

@ -81,6 +81,8 @@ class HfArgumentParser(ArgumentParser):
kwargs["type"] = field.type
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
else:
kwargs["required"] = True
self.add_argument(field_name, **kwargs)
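The new default_factory branch means dataclass fields whose defaults come from a factory (e.g. the logging_dir default mentioned in the commit message) become optional CLI arguments instead of required ones. A small, hypothetical illustration (ExampleArguments is not part of this PR):

from dataclasses import dataclass, field
from datetime import datetime

from transformers import HfArgumentParser

@dataclass
class ExampleArguments:  # hypothetical stand-in for a training-arguments dataclass
    logging_dir: str = field(default_factory=lambda: "runs/" + datetime.now().strftime("%b%d_%H-%M-%S"))

parser = HfArgumentParser(ExampleArguments)
(args,) = parser.parse_args_into_dataclasses(args=[])
print(args.logging_dir)  # the factory was called to produce the default, so --logging_dir is optional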

src/transformers/modeling_tf_albert.py

@ -23,7 +23,16 @@ import tensorflow as tf
from .configuration_albert import AlbertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from .modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
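The loss mixins imported here (TFSequenceClassificationLoss and friends) are defined in modeling_tf_utils.py, which is outside this excerpt. A minimal sketch of what the sequence-classification mixin can look like, consistent with the docstring below (MSE when config.num_labels == 1, sparse cross-entropy otherwise); the library implementation may differ in details:

import tensorflow as tf

class TFSequenceClassificationLoss:
    """Sketch of a mixin providing compute_loss(labels, logits) for classification/regression heads."""

    def compute_loss(self, labels, logits):
        if logits.shape.rank == 1 or logits.shape[-1] == 1:  # single output -> regression
            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:                                                # several labels -> classification
            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )
        return loss_fn(labels, logits)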
@ -841,7 +850,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
the pooled output) e.g. for GLUE tasks. """,
ALBERT_START_DOCSTRING,
)
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -852,9 +861,25 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`)
@ -878,27 +903,126 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = TFAlbertForSequenceClassification.from_pretrained('albert-base-v2')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
logits = outputs[0]
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
outputs = self.albert(inputs, **kwargs)
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # logits, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""Albert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
ALBERT_START_DOCSTRING,
)
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.albert = TFAlbertMainLayer(config, name="albert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import AlbertTokenizer, TFAlbertForTokenClassification
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = TFAlbertForTokenClassification.from_pretrained('albert-base-v2')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
ALBERT_START_DOCSTRING,
)
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -908,9 +1032,32 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@ -938,14 +1085,23 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = TFAlbertForQuestionAnswering.from_pretrained('albert-base-v2')
input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
start_scores, end_scores = model(tf.constant(input_ids)[None, :]) # Batch size 1
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
outputs = self.albert(inputs, **kwargs)
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
@ -956,7 +1112,13 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
outputs = (start_logits, end_logits,) + outputs[2:]
return outputs # start_logits, end_logits, (hidden_states), (attentions)
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
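compute_loss here comes from the TFQuestionAnsweringLoss mixin (also defined in modeling_tf_utils.py and not shown). A plausible sketch, averaging sparse cross-entropy over the start and end logits to match the label dict built above; details may differ from the actual implementation:

import tensorflow as tf

class TFQuestionAnsweringLoss:
    def compute_loss(self, labels, logits):
        # labels: {"start_position": ..., "end_position": ...}; logits: (start_logits, end_logits)
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])
        return (start_loss + end_loss) / 2.0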
@add_start_docstrings(
@ -964,7 +1126,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ALBERT_START_DOCSTRING,
)
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@ -992,9 +1154,15 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
@ -1019,12 +1187,13 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
example1 = ["This is a context", "Is it a context? Yes"]
example2 = ["This is a context", "Is it a context? No"]
encoding = tokenizer.batch_encode_plus([example1, example2], return_tensors='tf', truncation_strategy="only_first", pad_to_max_length=True, max_length=128)
outputs = model(encoding["input_ids"][None, :])
logits = outputs[0]
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
@ -1036,10 +1205,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
print("isdict(1)")
input_ids = inputs.get("input_ids")
print(input_ids)
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
@ -1080,4 +1246,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # reshaped_logits, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
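Similarly, TFMultipleChoiceLoss is not shown in this diff; conceptually it is a sparse cross-entropy over the reshaped (batch_size, num_choices) logits, with labels holding the index of the correct choice. A sketch under that assumption:

import tensorflow as tf

class TFMultipleChoiceLoss:
    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        return loss_fn(labels, logits)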

src/transformers/modeling_tf_auto.py

@ -22,14 +22,18 @@ from .configuration_auto import (
AlbertConfig,
AutoConfig,
BertConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
ElectraConfig,
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
RobertaConfig,
T5Config,
TransfoXLConfig,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .configuration_utils import PretrainedConfig
@ -39,6 +43,7 @@ from .modeling_tf_albert import (
TFAlbertForPreTraining,
TFAlbertForQuestionAnswering,
TFAlbertForSequenceClassification,
TFAlbertForTokenClassification,
TFAlbertModel,
)
from .modeling_tf_bert import (
@ -50,18 +55,43 @@ from .modeling_tf_bert import (
TFBertForTokenClassification,
TFBertModel,
)
from .modeling_tf_camembert import (
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering,
TFCamembertForSequenceClassification,
TFCamembertForTokenClassification,
TFCamembertModel,
)
from .modeling_tf_ctrl import TFCTRLLMHeadModel, TFCTRLModel
from .modeling_tf_distilbert import (
TFDistilBertForMaskedLM,
TFDistilBertForMultipleChoice,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertModel,
)
from .modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForTokenClassification,
TFElectraModel,
)
from .modeling_tf_flaubert import (
TFFlaubertForMultipleChoice,
TFFlaubertForQuestionAnsweringSimple,
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertWithLMHeadModel,
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from .modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
@ -70,12 +100,23 @@ from .modeling_tf_roberta import (
from .modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from .modeling_tf_transfo_xl import TFTransfoXLLMHeadModel, TFTransfoXLModel
from .modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMModel,
TFXLMWithLMHeadModel,
)
from .modeling_tf_xlm_roberta import (
TFXLMRobertaForMaskedLM,
TFXLMRobertaForMultipleChoice,
TFXLMRobertaForQuestionAnswering,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TFXLMRobertaModel,
)
from .modeling_tf_xlnet import (
TFXLNetForMultipleChoice,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
@ -89,83 +130,118 @@ logger = logging.getLogger(__name__)
TF_MODEL_MAPPING = OrderedDict(
[
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
(RobertaConfig, TFRobertaModel),
(BertConfig, TFBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(GPT2Config, TFGPT2Model),
(TransfoXLConfig, TFTransfoXLModel),
(XLNetConfig, TFXLNetModel),
(XLMConfig, TFXLMModel),
(CamembertConfig, TFCamembertModel),
(CTRLConfig, TFCTRLModel),
(DistilBertConfig, TFDistilBertModel),
(ElectraConfig, TFElectraModel),
(FlaubertConfig, TFFlaubertModel),
(GPT2Config, TFGPT2Model),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(RobertaConfig, TFRobertaModel),
(T5Config, TFT5Model),
(TransfoXLConfig, TFTransfoXLModel),
(XLMConfig, TFXLMModel),
(XLMRobertaConfig, TFXLMRobertaModel),
(XLNetConfig, TFXLNetModel),
]
)
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForPreTraining),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForPreTraining),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(CamembertConfig, TFCamembertForMaskedLM),
(CTRLConfig, TFCTRLLMHeadModel),
(DistilBertConfig, TFDistilBertForMaskedLM),
(ElectraConfig, TFElectraForPreTraining),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(RobertaConfig, TFRobertaForMaskedLM),
(T5Config, TFT5ForConditionalGeneration),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(XLNetConfig, TFXLNetLMHeadModel),
]
)
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(CamembertConfig, TFCamembertForMaskedLM),
(CTRLConfig, TFCTRLLMHeadModel),
(DistilBertConfig, TFDistilBertForMaskedLM),
(ElectraConfig, TFElectraForMaskedLM),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(RobertaConfig, TFRobertaForMaskedLM),
(T5Config, TFT5ForConditionalGeneration),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(XLNetConfig, TFXLNetLMHeadModel),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(CamembertConfig, TFCamembertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(XLNetConfig, TFXLNetForMultipleChoice),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
(AlbertConfig, TFAlbertForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(ElectraConfig, TFElectraForQuestionAnswering),
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
(RobertaConfig, TFRobertaForQuestionAnswering),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
(DistilBertConfig, TFDistilBertForSequenceClassification),
(AlbertConfig, TFAlbertForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(DistilBertConfig, TFDistilBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[(BertConfig, TFBertForMultipleChoice), (AlbertConfig, TFAlbertForMultipleChoice)]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
[
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(AlbertConfig, TFAlbertForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
]
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
(DistilBertConfig, TFDistilBertForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(AlbertConfig, TFAlbertForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(CamembertConfig, TFCamembertForTokenClassification),
(DistilBertConfig, TFDistilBertForTokenClassification),
(ElectraConfig, TFElectraForTokenClassification),
(FlaubertConfig, TFFlaubertForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),
]
)
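With these mappings in place, the TFAutoModelFor* factories can dispatch the new multiple-choice, question-answering and token-classification heads from a config class. A brief usage sketch (checkpoint name for illustration only):

from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFAutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")
# BertConfig is looked up in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING above,
# so `model` is a TFBertForQuestionAnswering instance.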
@ -632,11 +708,13 @@ class TFAutoModelWithLMHead(object):
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
if isinstance(config, config_class):
# Not using isinstance() here, so as not to take inheritance into account
if config_class == type(config):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"

src/transformers/modeling_tf_bert.py

@ -23,7 +23,16 @@ import tensorflow as tf
from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from .modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
@ -880,7 +889,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
the pooled output) e.g. for GLUE tasks. """,
BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel):
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -891,9 +900,25 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
@ -916,21 +941,35 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
logits = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
outputs = self.bert(inputs, **kwargs)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # logits, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
@ -938,7 +977,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel):
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@ -966,9 +1005,15 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
@ -993,15 +1038,14 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0 = "It is eaten with a fork and a knife."
choice1 = "It is eaten while held in the hand."
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='tf', pad_to_max_length=True)
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
# linear classifier on the output is not yet trained
outputs = model(encoding['input_ids'][None, :])
logits = outputs[0]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
@ -1011,7 +1055,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
@ -1053,7 +1097,11 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # reshaped_logits, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings(
@ -1061,7 +1109,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel):
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -1072,9 +1120,23 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@ -1098,20 +1160,33 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
outputs = self.bert(inputs, **kwargs)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # scores, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
@ -1119,7 +1194,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
the hidden-states output to compute `span start logits` and `span end logits`). """,
BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel):
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -1129,9 +1204,32 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@ -1156,18 +1254,24 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
encoding = tokenizer.encode_plus(question, text)
input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
start_scores, end_scores = model(tf.constant(input_ids)[None, :], token_type_ids=tf.constant(token_type_ids)[None, :])
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
answer = ' '.join(all_tokens[tf.math.argmax(tf.squeeze(start_scores)) : tf.math.argmax(tf.squeeze(end_scores))+1])
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
assert answer == "a nice puppet"
"""
outputs = self.bert(inputs, **kwargs)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
@ -1178,4 +1282,10 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
outputs = (start_logits, end_logits,) + outputs[2:]
return outputs # start_logits, end_logits, (hidden_states), (attentions)
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)

src/transformers/modeling_tf_camembert.py

@ -22,6 +22,8 @@ from .configuration_camembert import CamembertConfig
from .file_utils import add_start_docstrings
from .modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaModel,
@ -114,3 +116,30 @@ class TFCamembertForTokenClassification(TFRobertaForTokenClassification):
"""
config_class = CamembertConfig
@add_start_docstrings(
"""CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
CAMEMBERT_START_DOCSTRING,
)
class TFCamembertForMultipleChoice(TFRobertaForMultipleChoice):
"""
This class overrides :class:`~transformers.TFRobertaForMultipleChoice`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class = CamembertConfig
@add_start_docstrings(
"""CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
CAMEMBERT_START_DOCSTRING,
)
class TFCamembertForQuestionAnswering(TFRobertaForQuestionAnswering):
"""
This class overrides :class:`~transformers.TFRobertaForQuestionAnswering`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class = CamembertConfig

src/transformers/modeling_tf_distilbert.py

@ -23,8 +23,18 @@ import numpy as np
import tensorflow as tf
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
@ -399,7 +409,10 @@ class TFTransformer(tf.keras.layers.Layer):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
@keras_serializable
class TFDistilBertMainLayer(tf.keras.layers.Layer):
config_class = DistilBertConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
@ -662,7 +675,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
the pooled output) e.g. for GLUE tasks. """,
DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -680,8 +693,16 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
@ -705,20 +726,32 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
logits = outputs[0]
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
distilbert_output = self.distilbert(inputs, **kwargs)
distilbert_output = self.distilbert(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False)) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
outputs = (logits,) + distilbert_output[1:]
return outputs # logits, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
@ -726,7 +759,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -738,8 +771,14 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
)
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@ -762,20 +801,154 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = TFDistilBertForTokenClassification.from_pretrained('distilbert-base-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
outputs = self.distilbert(inputs, **kwargs)
outputs = self.distilbert(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # scores, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""DistilBert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.pre_classifier = tf.keras.layers.Dense(
config.dim,
kernel_initializer=get_initializer(config.initializer_range),
activation="relu",
name="pre_classifier",
)
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(
self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertForMultipleChoice
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = TFDistilBertForMultipleChoice.from_pretrained('distilbert-base-uncased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
assert len(inputs) <= 4, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 4, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_inputs = [
flat_input_ids,
flat_attention_mask,
head_mask,
inputs_embeds,
]
distilbert_output = self.distilbert(flat_inputs, training=training)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + distilbert_output[1:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
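The multiple choice head above works by folding the choices into the batch dimension: inputs of shape (batch_size, num_choices, seq_length) are flattened to (batch_size * num_choices, seq_length) before the encoder, and the per-choice scores are reshaped back afterwards. A shape-only sketch of that round trip (plain TensorFlow, no model weights involved):
import tensorflow as tf

batch_size, num_choices, seq_length = 2, 4, 8
input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)

# flatten the choices into the batch dimension, as call() does above
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))    # shape (8, 8)

# pretend the encoder + classifier produced one score per flattened row
logits = tf.zeros((batch_size * num_choices, 1))

# fold the scores back so each row holds one example's choice scores
reshaped_logits = tf.reshape(logits, (-1, num_choices))     # shape (2, 4)
print(flat_input_ids.shape, reshaped_logits.shape)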
@add_start_docstrings(
@ -783,7 +956,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
the hidden-states output to compute `span start logits` and `span end logits`). """,
DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@ -795,8 +968,29 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
self.dropout = tf.keras.layers.Dropout(config.qa_dropout)
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@ -821,19 +1015,35 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
start_scores, end_scores = outputs[:2]
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
distilbert_output = self.distilbert(inputs, **kwargs)
distilbert_output = self.distilbert(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=kwargs.get("training", False)) # (bs, max_query_len, dim)
hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim)
logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
outputs = (start_logits, end_logits,) + distilbert_output[1:]
return outputs # start_logits, end_logits, (hidden_states), (attentions)
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
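With start_positions and end_positions supplied, the question answering heads in this diff prepend the averaged span loss to their outputs. A hedged usage sketch (the checkpoint name and gold span indices are illustrative):
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertForQuestionAnswering

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased')

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer.encode_plus(question, text, return_tensors='tf')

# toy gold span (token indices into the encoded sequence, batch of 1)
start_positions = tf.constant([1])
end_positions = tf.constant([3])

outputs = model(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    start_positions=start_positions,
    end_positions=end_positions,
)
loss, start_logits, end_logits = outputs[:3]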

View File

@ -6,7 +6,13 @@ from transformers import ElectraConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import get_initializer, shape_list
from .modeling_tf_utils import (
TFQuestionAnsweringLoss,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
@ -194,6 +200,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
return head_mask
@keras_serializable
class TFElectraMainLayer(TFElectraPreTrainedModel):
config_class = ElectraConfig
@ -557,13 +564,15 @@ Electra model with a token classification head on top.
Both the discriminator and generator may be loaded into this model.""",
ELECTRA_START_DOCSTRING,
)
class TFElectraForTokenClassification(TFElectraPreTrainedModel):
class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, **kwargs):
super().__init__(config, **kwargs)
self.electra = TFElectraMainLayer(config, name="electra")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(
@ -574,9 +583,14 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@ -599,9 +613,11 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel):
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
model = TFElectraForTokenClassification.from_pretrained('google/electra-small-discriminator')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
discriminator_hidden_states = self.electra(
@ -610,7 +626,106 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel):
discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output)
output = (logits,)
output += discriminator_hidden_states[1:]
return output # (loss), scores, (hidden_states), (attentions)
outputs = (logits,) + discriminator_hidden_states[1:]
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
@add_start_docstrings(
"""Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
ELECTRA_START_DOCSTRING,
)
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.electra = TFElectraMainLayer(config, name="electra")
self.qa_outputs = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForQuestionAnswering
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
model = TFElectraForQuestionAnswering.from_pretrained('google/electra-small-generator')
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.qa_outputs(discriminator_sequence_output)
start_logits, end_logits = tf.split(logits, 2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)
end_logits = tf.squeeze(end_logits, axis=-1)
outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)

View File

@ -22,13 +22,16 @@ import tensorflow as tf
from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
from .modeling_tf_utils import keras_serializable, shape_list
from .modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMMainLayer,
TFXLMModel,
TFXLMWithLMHeadModel,
get_masks,
shape_list,
)
from .tokenization_utils import BatchEncoding
@ -112,6 +115,7 @@ class TFFlaubertModel(TFXLMModel):
self.transformer = TFFlaubertMainLayer(config, name="transformer")
@keras_serializable
class TFFlaubertMainLayer(TFXLMMainLayer):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
@ -327,3 +331,38 @@ class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
@add_start_docstrings(
"""Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
config_class = FlaubertConfig
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
@add_start_docstrings(
"""Flaubert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
@add_start_docstrings(
"""Flaubert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFFlaubertMainLayer(config, name="transformer")
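The four Flaubert heads above add no new modelling code: each subclasses the corresponding XLM head (and so inherits its loss mixin) and only swaps in TFFlaubertMainLayer and FlaubertConfig. The pattern in isolation, as a toy sketch with plain tf.keras (all names below are illustrative stand-ins, not part of this diff):
import tensorflow as tf

class ToyEncoder(tf.keras.layers.Layer):
    # stand-in for TFXLMMainLayer
    def __init__(self, hidden_size=8, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden_size)

    def call(self, inputs):
        return self.dense(inputs)

class ToyVariantEncoder(ToyEncoder):
    # stand-in for TFFlaubertMainLayer: same interface, different internals
    pass

class ToyTokenClassifier(tf.keras.Model):
    # stand-in for TFXLMForTokenClassification
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.transformer = ToyEncoder(name="transformer")
        self.classifier = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.classifier(self.transformer(inputs))

class ToyVariantTokenClassifier(ToyTokenClassifier):
    # stand-in for TFFlaubertForTokenClassification: reuse the head, replace the encoder
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.transformer = ToyVariantEncoder(name="transformer")

print(ToyVariantTokenClassifier()(tf.zeros((1, 4, 8))).shape)  # (1, 4, 2)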

View File

@ -29,6 +29,7 @@ from .modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
@ -199,7 +200,10 @@ class TFBlock(tf.keras.layers.Layer):
return outputs # x, (attentions)
@keras_serializable
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
config_class = OpenAIGPTConfig
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states

View File

@ -21,9 +21,18 @@ import logging
import tensorflow as tf
from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
from .modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
)
logger = logging.getLogger(__name__)
@ -82,6 +91,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
return super()._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
@keras_serializable
class TFRobertaMainLayer(TFBertMainLayer):
"""
Same as TFBertMainLayer but uses TFRobertaEmbeddings.
@ -337,7 +347,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
on top of the pooled output) e.g. for GLUE tasks. """,
ROBERTA_START_DOCSTRING,
)
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -346,7 +356,17 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
self.classifier = TFRobertaClassificationHead(config, name="classifier")
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
@ -370,20 +390,164 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForSequenceClassification.from_pretrained('roberta-base')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.constant([1])[None, :] # Batch size 1
outputs = model(input_ids)
logits = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
outputs = self.roberta(inputs, **kwargs)
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output, training=kwargs.get("training", False))
logits = self.classifier(sequence_output, training=training)
outputs = (logits,) + outputs[2:]
return outputs # logits, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""Roberta Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ROBERTA_START_DOCSTRING,
)
class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.roberta = TFRobertaMainLayer(config, name="roberta")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(
self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import RobertaTokenizer, TFRobertaForMultipleChoice
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForMultipleChoice.from_pretrained('roberta-base')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs = [
flat_input_ids,
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
inputs_embeds,
]
outputs = self.roberta(flat_inputs, training=training)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=training)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings(
@ -391,7 +555,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
ROBERTA_START_DOCSTRING,
)
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -403,8 +567,22 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
)
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@ -428,27 +606,40 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForTokenClassification.from_pretrained('roberta-base')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
outputs = self.roberta(inputs, **kwargs)
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # scores, (hidden_states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
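When building labels for this token classification head, the -1 convention used by TFTokenClassificationLoss (further down in this diff) is what lets special tokens and trailing sub-word pieces be skipped. A hedged preprocessing sketch, assuming one tag id per word and labelling only each word's first sub-token (the word/tag values are illustrative):
import tensorflow as tf
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

words = ["Hello", "world"]
word_tags = [3, 7]  # toy tag ids, one per word

input_ids, labels = [tokenizer.bos_token_id], [-1]          # -1 = ignored by the loss
for word, tag in zip(words, word_tags):
    pieces = tokenizer.encode(" " + word, add_special_tokens=False)
    input_ids.extend(pieces)
    labels.extend([tag] + [-1] * (len(pieces) - 1))         # tag only the first sub-token
input_ids.append(tokenizer.eos_token_id)
labels.append(-1)

input_ids = tf.constant([input_ids])
labels = tf.constant([labels])
# feed input_ids and labels to TFRobertaForTokenClassification as in the docstring example above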
@add_start_docstrings(
"""RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
ROBERTA_START_DOCSTRING,
)
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -459,8 +650,31 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
)
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@ -488,14 +702,23 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForQuestionAnswering.from_pretrained('roberta-base')
input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
start_scores, end_scores = model(tf.constant(input_ids)[None, :]) # Batch size 1
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
outputs = self.roberta(inputs, **kwargs)
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = outputs[0]
@ -506,4 +729,10 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
outputs = (start_logits, end_logits,) + outputs[2:]
return outputs # start_logits, end_logits, (hidden_states), (attentions)
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)

View File

@ -25,7 +25,8 @@ import tensorflow as tf
from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
@ -502,7 +503,10 @@ class _NoLayerEmbedTokens(object):
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
@keras_serializable
class TFT5MainLayer(tf.keras.layers.Layer):
config_class = T5Config
def __init__(self, config, embed_tokens=None, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
@ -548,12 +552,32 @@ class TFT5MainLayer(tf.keras.layers.Layer):
use_cache=False,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
head_mask = inputs[5] if len(inputs) > 5 else head_mask
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("decoder_input_ids")
attention_mask = inputs.get("decoder_attention_mask", attention_mask)
encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
if inputs is not None and inputs_embeds is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
elif inputs is not None:
input_shape = shape_list(inputs)
inputs = tf.reshape(inputs, (-1, input_shape[-1]))
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
@ -561,7 +585,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(inputs)
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
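The reworked block above lets the T5 decoder stack take its arguments positionally (tuple/list), as a dict/BatchEncoding keyed by the decoder_* names, or as a bare tensor. The same dispatch, isolated as a standalone sketch (the helper name is illustrative, not part of this diff):
import tensorflow as tf

def parse_flexible_inputs(inputs, attention_mask=None):
    # mirrors the tuple / dict / tensor handling used by the TF main layers
    if isinstance(inputs, (tuple, list)):
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
        assert len(inputs) <= 2, "Too many inputs."
    elif isinstance(inputs, dict):
        input_ids = inputs.get("decoder_input_ids")
        attention_mask = inputs.get("decoder_attention_mask", attention_mask)
    else:
        input_ids = inputs
    return input_ids, attention_mask

ids = tf.constant([[5, 6, 7]])
mask = tf.constant([[1, 1, 1]])
print(parse_flexible_inputs(ids))                     # bare tensor
print(parse_flexible_inputs((ids, mask)))             # positional tuple
print(parse_flexible_inputs({"decoder_input_ids": ids, "decoder_attention_mask": mask}))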

View File

@ -734,7 +734,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
return outputs
class TFTransfoXLLMHead(tf.keras.layers.Layer):
class TFTransfoXLMHead(tf.keras.layers.Layer):
def __init__(self, config, input_embeddings, **kwargs):
super().__init__(**kwargs)
self.vocab_size = config.vocab_size

View File

@ -84,6 +84,7 @@ def keras_serializable(cls):
else:
raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
self._transformers_config = config
self._kwargs = kwargs
cls.__init__ = wrapped_init
@ -94,6 +95,7 @@ def keras_serializable(cls):
def get_config(self):
cfg = super(cls, self).get_config()
cfg["transformers_config"] = self._transformers_config.to_dict()
cfg.update(self._kwargs)
return cfg
cls.get_config = get_config
@ -104,6 +106,44 @@ def keras_serializable(cls):
return cls
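The decorator now keeps any extra constructor kwargs and folds them back into get_config(), so a decorated layer round-trips through the standard Keras (de)serialization path. A simplified analog with plain tf.keras, not using the decorator itself (all names are illustrative):
import tensorflow as tf

class ConfigCarryingLayer(tf.keras.layers.Layer):
    # simplified analog of a @keras_serializable layer: remember the init kwargs
    # and surface them in get_config() so Keras can rebuild the layer
    def __init__(self, hidden_size=16, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self._init_kwargs = {"hidden_size": hidden_size, "dropout": dropout}
        self.dense = tf.keras.layers.Dense(hidden_size)
        self.drop = tf.keras.layers.Dropout(dropout)

    def call(self, inputs, training=False):
        return self.drop(self.dense(inputs), training=training)

    def get_config(self):
        cfg = super().get_config()
        cfg.update(self._init_kwargs)   # same idea as cfg.update(self._kwargs) above
        return cfg

layer = ConfigCarryingLayer(hidden_size=32, dropout=0.2, name="toy")
rebuilt = ConfigCarryingLayer.from_config(layer.get_config())
print(rebuilt.get_config()["hidden_size"], rebuilt.get_config()["dropout"])  # 32 0.2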
class TFQuestionAnsweringLoss:
def compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
start_loss = loss_fn(labels["start_position"], logits[0])
end_loss = loss_fn(labels["end_position"], logits[1])
return (start_loss + end_loss) / 2.0
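TFQuestionAnsweringLoss averages two sparse cross-entropies, one over the start index and one over the end index; `labels` is a dict and `logits` an indexable pair of (batch_size, seq_len) tensors. A quick hedged check, instantiating the mixin on its own (which works here because it carries no state):
import tensorflow as tf
from transformers.modeling_tf_utils import TFQuestionAnsweringLoss

batch_size, seq_len = 2, 6
start_logits = tf.random.normal((batch_size, seq_len))
end_logits = tf.random.normal((batch_size, seq_len))
labels = {"start_position": tf.constant([1, 3]), "end_position": tf.constant([2, 5])}

loss = TFQuestionAnsweringLoss().compute_loss(labels, (start_logits, end_logits))
print(loss.shape)  # (2,) -- one unreduced loss value per example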
class TFTokenClassificationLoss:
def compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
active_loss = tf.reshape(labels, (-1,)) != -1
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
return loss_fn(labels, reduced_logits)
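TFTokenClassificationLoss treats a label of -1 as "ignore this position" (padding, special tokens, trailing sub-word pieces): those rows are dropped with tf.boolean_mask before the cross-entropy. A small hedged check:
import tensorflow as tf
from transformers.modeling_tf_utils import TFTokenClassificationLoss

logits = tf.random.normal((1, 4, 3))        # (batch, seq_len, num_labels)
labels = tf.constant([[2, -1, 0, -1]])      # -1 marks positions to ignore

loss = TFTokenClassificationLoss().compute_loss(labels, logits)
print(loss.shape)  # (2,) -- only the two non-ignored tokens contribute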
class TFSequenceClassificationLoss:
def compute_loss(self, labels, logits):
if shape_list(logits)[1] == 1:
loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
else:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
return loss_fn(labels, logits)
TFMultipleChoiceLoss = TFSequenceClassificationLoss
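TFSequenceClassificationLoss switches on the width of the logits: a single output column is treated as regression (mean-squared error), anything wider as classification (sparse cross-entropy). TFMultipleChoiceLoss is just an alias because the reshaped (batch_size, num_choices) scores go through the same cross-entropy. A hedged sketch of both branches:
import tensorflow as tf
from transformers.modeling_tf_utils import TFSequenceClassificationLoss

loss_mixin = TFSequenceClassificationLoss()

# num_labels > 1: sparse categorical cross-entropy over class logits
clf_logits = tf.random.normal((2, 3))
clf_labels = tf.constant([0, 2])
print(loss_mixin.compute_loss(clf_labels, clf_logits).shape)   # (2,)

# num_labels == 1: mean-squared error against float targets (regression)
reg_logits = tf.random.normal((2, 1))
reg_labels = tf.constant([[0.5], [1.5]])
print(loss_mixin.compute_loss(reg_labels, reg_logits).shape)   # (2,)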
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
r""" Base class for all TF models.
@ -1531,6 +1571,16 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
)
super().build(input_shape)
def get_config(self):
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
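With vocab_size, hidden_size and initializer_range now exposed through get_config(), a TFSharedEmbeddings layer can be rebuilt from its Keras config, which is what saving a model containing it relies on. A hedged round-trip sketch (the constructor signature is assumed from this file, not shown in the diff):
import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings

emb = TFSharedEmbeddings(vocab_size=100, hidden_size=16, initializer_range=0.02, name="shared")
cfg = emb.get_config()
print({k: cfg[k] for k in ("vocab_size", "hidden_size", "initializer_range")})

# from_config() feeds the dict back into the constructor, rebuilding an equivalent layer
rebuilt = TFSharedEmbeddings.from_config(cfg)
print(rebuilt.vocab_size, rebuilt.hidden_size)  # 100 16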
def call(self, inputs, mode="embedding"):
"""Get token embeddings of inputs.
Args:

View File

@ -24,8 +24,19 @@ import numpy as np
import tensorflow as tf
from .configuration_xlm import XLMConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSequenceSummary,
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
@ -198,7 +209,10 @@ class TFTransformerFFN(tf.keras.layers.Layer):
return x
@keras_serializable
class TFXLMMainLayer(tf.keras.layers.Layer):
config_class = XLMConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
@ -717,7 +731,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
the pooled output) e.g. for GLUE tasks. """,
XLM_START_DOCSTRING,
)
class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -726,8 +740,27 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, config.num_labels)`):
@ -751,19 +784,261 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
model = TFXLMForSequenceClassification.from_pretrained('xlm-mlm-en-2048')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.constant([1])[None, :] # Batch size 1
outputs = model(input_ids)
logits = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
transformer_outputs = self.transformer(inputs, **kwargs)
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
outputs = (logits,) + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
return outputs
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""XLM Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
XLM_START_DOCSTRING,
)
class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name="transformer")
self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(
self,
inputs,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import XLMTokenizer, TFXLMForMultipleChoice
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
model = TFXLMForMultipleChoice.from_pretrained('xlm-mlm-en-2048')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
langs = inputs[2] if len(inputs) > 2 else langs
token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
position_ids = inputs[4] if len(inputs) > 4 else position_ids
lengths = inputs[5] if len(inputs) > 5 else lengths
cache = inputs[6] if len(inputs) > 6 else cache
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
langs = inputs.get("langs", langs)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
lengths = inputs.get("lengths", lengths)
cache = inputs.get("cache", cache)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs = [
flat_input_ids,
flat_attention_mask,
langs,
flat_token_type_ids,
flat_position_ids,
lengths,
cache,
head_mask,
inputs_embeds,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings(
"""XLM Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLM_START_DOCSTRING,
)
class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.transformer = TFXLMMainLayer(config, name="transformer")
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import XLMTokenizer, TFXLMForTokenClassification
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
model = TFXLMForTokenClassification.from_pretrained('xlm-mlm-en-2048')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = transformer_outputs[0]
sequence_output = self.dropout(sequence_output, training=training)
logits = self.classifier(sequence_output)
outputs = (logits,) + transformer_outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
@ -771,7 +1046,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLM_START_DOCSTRING,
)
class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLMMainLayer(config, name="transformer")
@ -780,8 +1055,34 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
)
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span, used for computing the question answering loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
start_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length,)`):
@ -807,12 +1108,27 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
model = TFXLMForQuestionAnsweringSimple.from_pretrained('xlm-mlm-en-2048')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
start_scores, end_scores = outputs[:2]
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
transformer_outputs = self.transformer(inputs, **kwargs)
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
training=training,
)
sequence_output = transformer_outputs[0]
@ -825,4 +1141,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
1:
] # Keep mems, hidden states, attentions if there are in it
return outputs # start_logits, end_logits, (hidden_states), (attentions)
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)

View File

@ -22,6 +22,8 @@ from .configuration_xlm_roberta import XLMRobertaConfig
from .file_utils import add_start_docstrings
from .modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaModel,
@ -114,3 +116,30 @@ class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForQuestionAnswering(TFRobertaForQuestionAnswering):
"""
This class overrides :class:`~transformers.TFRobertaForQuestionAnswering`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""Roberta Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
XLM_ROBERTA_START_DOCSTRING,
)
class TFXLMRobertaForMultipleChoice(TFRobertaForMultipleChoice):
"""
This class overrides :class:`~transformers.TFRobertaForMultipleChoice`. Please check the
superclass for the appropriate documentation alongside usage examples.
"""
config_class = XLMRobertaConfig

View File

@ -23,11 +23,15 @@ import numpy as np
import tensorflow as tf
from .configuration_xlnet import XLNetConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSequenceSummary,
TFSharedEmbeddings,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
@ -938,7 +942,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
the pooled output) e.g. for GLUE tasks. """,
XLNET_START_DOCSTRING,
)
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@ -952,8 +956,28 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
)
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, config.num_labels)`):
@@ -981,12 +1005,24 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = TFXLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
logits = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
transformer_outputs = self.transformer(inputs, **kwargs)
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
output = transformer_outputs[0]
output = self.sequence_summary(output)
@@ -994,7 +1030,159 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, (mems), (hidden states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (mems), (hidden_states), (attentions)
@add_start_docstrings(
"""XLNET Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
XLNET_START_DOCSTRING,
)
class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name="transformer")
self.sequence_summary = TFSequenceSummary(
config, initializer_range=config.initializer_range, name="sequence_summary"
)
self.logits_proj = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
)
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(
self,
inputs,
token_type_ids=None,
input_mask=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import XLNetTokenizer, TFXLNetForMultipleChoice
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = TFXLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
mems = inputs[2] if len(inputs) > 2 else mems
perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
mems = inputs.get("mems", mems)
perm_mask = inputs.get("perm_mask", perm_mask)
target_mapping = inputs.get("target_mapping", target_mapping)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
flat_inputs = [
flat_input_ids,
flat_attention_mask,
mems,
perm_mask,
target_mapping,
flat_token_type_ids,
flat_input_mask,
head_mask,
inputs_embeds,
use_cache,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
logits = self.logits_proj(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (mems), (hidden states), (attentions)
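What the multiple-choice head does with its inputs is easiest to see on shapes alone: choices arrive as (batch_size, num_choices, seq_length), are flattened to one row per choice before going through the transformer, and the per-choice scores are folded back into one row per example. A standalone, shape-only sketch (illustrative values, not from the diff):
import tensorflow as tf
batch_size, num_choices, seq_length = 2, 4, 16
input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)
flat_input_ids = tf.reshape(input_ids, (-1, seq_length))            # (8, 16): one row per choice
per_choice_scores = tf.zeros((batch_size * num_choices, 1))         # stand-in for the logits_proj output
reshaped_logits = tf.reshape(per_choice_scores, (-1, num_choices))  # (2, 4): one score per choice, per example
print(flat_input_ids.shape, reshaped_logits.shape)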
@add_start_docstrings(
@@ -1002,7 +1190,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLNET_START_DOCSTRING,
)
class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
@@ -1012,8 +1200,26 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@@ -1041,19 +1247,36 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = TFXLNetForTokenClassification.from_pretrained('xlnet-large-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
transformer_outputs = self.transformer(inputs, **kwargs)
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
training=training,
)
output = transformer_outputs[0]
logits = self.classifier(output)
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, (mems), (hidden states), (attentions)
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (mems), (hidden_states), (attentions)
@add_start_docstrings(
@@ -1061,7 +1284,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
the hidden-states output to compute `span start logits` and `span end logits`). """,
XLNET_START_DOCSTRING,
)
class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.transformer = TFXLNetMainLayer(config, name="transformer")
@@ -1070,8 +1293,35 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
)
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
start_positions=None,
end_positions=None,
cls_index=None,
p_mask=None,
is_impossible=None,
training=False,
):
r"""
start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
@@ -1103,12 +1353,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = TFXLNetForQuestionAnsweringSimple.from_pretrained('xlnet-base-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
outputs = model(input_ids)
start_scores, end_scores = outputs[:2]
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
start_scores, end_scores = model(input_dict)
all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
"""
transformer_outputs = self.transformer(inputs, **kwargs)
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
training=training,
)
sequence_output = transformer_outputs[0]
@@ -1121,7 +1386,13 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
1:
] # Keep mems, hidden states, attentions if there are in it
return outputs # start_logits, end_logits, (mems), (hidden_states), (attentions)
if start_positions is not None and end_positions is not None:
labels = {"start_position": start_positions}
labels["end_position"] = end_positions
loss = self.compute_loss(labels, outputs[:2])
outputs = (loss,) + outputs
return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
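With TFQuestionAnsweringLoss mixed in, passing start_positions and end_positions makes the model prepend the loss to its outputs, mirroring the PyTorch models. A hedged sketch (not from the diff) with toy span labels:
import tensorflow as tf
from transformers import XLNetTokenizer, TFXLNetForQuestionAnsweringSimple
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = TFXLNetForQuestionAnsweringSimple.from_pretrained("xlnet-base-cased")
inputs = tokenizer.encode_plus("Who was Jim Henson?", "Jim Henson was a nice puppet", return_tensors="tf")
start_positions = tf.constant([1])  # toy gold span, batch size 1
end_positions = tf.constant([3])
outputs = model(inputs, start_positions=start_positions, end_positions=end_positions)
loss, start_logits, end_logits = outputs[:3]  # loss comes first when both labels are given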

View File

@@ -58,27 +58,41 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
}
def create_optimizer(init_lr, num_train_steps, num_warmup_steps, end_lr=0.0, optimizer_type="adamw"):
def create_optimizer(
init_lr,
num_train_steps,
num_warmup_steps,
min_lr_ratio=0.0,
adam_epsilon=1e-8,
weight_decay_rate=0.0,
include_in_weight_decay=None,
):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=end_lr,
initial_learning_rate=init_lr,
decay_steps=num_train_steps - num_warmup_steps,
end_learning_rate=init_lr * min_lr_ratio,
)
if num_warmup_steps:
lr_schedule = WarmUp(
initial_learning_rate=init_lr, decay_schedule_fn=lr_schedule, warmup_steps=num_warmup_steps,
)
if weight_decay_rate > 0.0:
optimizer = AdamWeightDecay(
learning_rate=lr_schedule,
weight_decay_rate=0.01,
weight_decay_rate=weight_decay_rate,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
epsilon=adam_epsilon,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
include_in_weight_decay=include_in_weight_decay,
)
return optimizer
else:
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=adam_epsilon)
# We return the optimizer and the LR scheduler in order to better track the
# evolution of the LR independently of the optimizer.
return optimizer, lr_schedule
class AdamWeightDecay(tf.keras.optimizers.Adam):
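create_optimizer now always returns an (optimizer, lr_schedule) pair: AdamWeightDecay when weight_decay_rate > 0, plain Adam otherwise, with the schedule returned separately so the learning rate can be tracked independently of the optimizer. A short usage sketch (hyper-parameter values are illustrative):
from transformers.optimization_tf import create_optimizer
num_train_steps, num_warmup_steps = 1000, 100
optimizer, lr_schedule = create_optimizer(
    init_lr=3e-5,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    weight_decay_rate=0.01,  # > 0 selects AdamWeightDecay instead of plain Adam
)
# The schedule can be queried directly, e.g. for logging the learning rate per step.
print(float(lr_schedule(0)), float(lr_schedule(num_warmup_steps)), float(lr_schedule(num_train_steps)))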

View File

@@ -3,12 +3,12 @@
import logging
import math
import os
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
from .modeling_tf_utils import TFPreTrainedModel, shape_list
from .modeling_tf_utils import TFPreTrainedModel
from .optimization_tf import GradientAccumulator, create_optimizer
from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput
from .training_args_tf import TFTrainingArguments
@@ -20,13 +20,14 @@ logger = logging.getLogger(__name__)
class TFTrainer:
model: TFPreTrainedModel
args: TFTrainingArguments
# something similar to a PT Dataset.
# This is just a temporary workaround until
# a framework-agnostic approach for datasets is available.
train_dataset: Optional[tf.data.Dataset]
eval_dataset: Optional[tf.data.Dataset]
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
prediction_loss_only: bool
tb_writer: Optional[tf.summary.SummaryWriter] = None
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None
global_step: Optional[int] = None
epoch: Optional[float] = None
def __init__(
self,
@@ -36,6 +37,8 @@ class TFTrainer:
eval_dataset: Optional[tf.data.Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False,
tb_writer: Optional[tf.summary.SummaryWriter] = None,
optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None,
):
self.model = model
self.args = args
@@ -43,55 +46,18 @@ class TFTrainer:
self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only
self.optimizers = optimizers
self.gradient_accumulator = GradientAccumulator()
self._setup_training()
if tb_writer is not None:
self.tb_writer = tb_writer
else:
self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
def _setup_training(self) -> None:
"""
Setup the different steps to train a model:
- check if all the data are given
- create the proper strategy
- create the features
- prepare the model settings
"""
self._prepare_dataset()
def get_train_tfdataset(self) -> tf.data.Dataset:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
with self.args.strategy.scope():
self._create_optimizer()
_ = self.optimizer.iterations
self._set_loss_and_metric()
self._create_checkpoint_manager()
self._create_summary_writer()
def _set_loss_and_metric(self) -> None:
"""
Create the training loss and metric with their name. Allowed names are those listed
in the Tensorflow documentation and those contained in the transformers library.
"""
try:
self.loss = tf.keras.losses.get(
{
"class_name": self.args.loss_name,
"config": {"from_logits": True, "reduction": tf.keras.losses.Reduction.NONE},
}
)
except TypeError:
self.loss = tf.keras.losses.get(
{"class_name": self.args.loss_name, "config": {"reduction": tf.keras.losses.Reduction.NONE}}
)
def _create_summary_writer(self) -> None:
"""
Create a summary writer to be able to read the logs in Tensorboard.
"""
self.writer = tf.summary.create_file_writer(self.args.logging_dir)
def _prepare_dataset(self) -> None:
"""
Prepare the training, validation and test data.
"""
if self.train_dataset is not None:
self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
if self.args.max_steps > 0:
@@ -99,7 +65,7 @@ class TFTrainer:
else:
self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
self.train_dataset = (
ds = (
self.train_dataset.cache()
.shuffle(self.num_train_examples)
.batch(self.args.train_batch_size)
@@ -109,54 +75,44 @@ class TFTrainer:
if self.args.max_steps > 0:
self.train_dataset = self.train_dataset.repeat(-1)
self.train_dataset = self.args.strategy.experimental_distribute_dataset(self.train_dataset)
else:
self.train_steps = 0
return self.args.strategy.experimental_distribute_dataset(ds)
if self.eval_dataset is not None:
self.eval_dataset = (
self.eval_dataset.batch(self.args.eval_batch_size).cache().prefetch(tf.data.experimental.AUTOTUNE)
)
self.eval_dataset = self.args.strategy.experimental_distribute_dataset(self.eval_dataset)
def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
def _create_optimizer(self) -> None:
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
ds = eval_dataset.cache().batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
return self.args.strategy.experimental_distribute_dataset(ds)
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
ds = test_dataset.batch(self.args.eval_batch_size)
return self.args.strategy.experimental_distribute_dataset(ds)
def get_optimizers(
self,
) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
"""
Create the training optimizer with its name. Allowed names are those listed
in the Tensorflow documentation and those contained in the transformers library.
"""
if self.args.optimizer_name == "adamw":
self.optimizer = create_optimizer(
self.args.learning_rate, self.train_steps, self.args.warmup_steps, self.args.end_lr
)
else:
try:
self.optimizer = tf.keras.optimizers.get(
{
"class_name": self.args.optimizer_name,
"config": {"learning_rate": self.args.learning_rate, "epsilon": self.args.adam_epsilon},
}
)
except TypeError:
# This is for the case where the optimizer is not Adam-like such as SGD
self.optimizer = tf.keras.optimizers.get(
{"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
)
logger.info("Created an/a {} optimizer".format(self.args.optimizer_name))
Setup the optimizer and the learning rate scheduler.
def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
We provide a reasonable default that works well.
If you want to use something else, you can pass a tuple in the Trainer's init,
or override this method in a subclass.
"""
Create a checkpoint manager in order to be able to make the training
fault-tolerant.
Args:
max_to_keep: the maximum number of checkpoints to keep in the checkpoint path.
load_model: if we want to start the training from the latest checkpoint.
"""
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
if self.optimizers is not None:
return self.optimizers
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=max_to_keep)
optimizer, scheduler = create_optimizer(
self.args.learning_rate,
self.train_steps,
self.args.warmup_steps,
adam_epsilon=self.args.adam_epsilon,
weight_decay_rate=self.args.weight_decay,
)
if load_model:
ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
return optimizer, scheduler
@tf.function
def _evaluate_steps(self, per_replica_features, per_replica_labels):
@@ -182,6 +138,14 @@ class TFTrainer:
def _prediction_loop(
self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
Works with or without labels.
"""
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
logger.info("***** Running %s *****", description)
logger.info(" Batch size = %d", self.args.eval_batch_size)
@@ -196,6 +160,12 @@ class TFTrainer:
loss = tf.reduce_mean(loss)
if not prediction_loss_only:
if isinstance(logits, tuple):
logits = logits[0]
if isinstance(labels, tuple):
labels = labels[0]
if self.args.n_gpu > 1:
for val in logits.values:
if preds is None:
@@ -240,10 +210,9 @@ class TFTrainer:
"""
Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
"""
if eval_dataset is None:
eval_dataset = self.eval_dataset
eval_ds = self.get_eval_tfdataset(eval_dataset)
output = self._prediction_loop(eval_dataset, description="Evaluation")
output = self._prediction_loop(eval_ds, description="Evaluation")
return output.metrics
@@ -251,12 +220,25 @@ class TFTrainer:
"""
Train method to train the model.
"""
train_ds = self.get_train_tfdataset()
if self.args.debug:
tf.summary.trace_on(graph=True, profiler=True)
self.gradient_accumulator.reset()
iterations = self.optimizer.iterations
with self.args.strategy.scope():
optimizer, lr_scheduler = self.get_optimizers()
iterations = optimizer.iterations
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=5)
if self.model.ckpt_manager.latest_checkpoint:
logger.info(
"Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
)
ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
if iterations.numpy() > 0:
logger.info("Start the training from the last checkpoint")
@@ -268,21 +250,30 @@ class TFTrainer:
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
if self.args.fp16:
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)
with self.tb_writer.as_default():
tf.summary.text("args", self.args.to_json_string())
self.tb_writer.flush()
logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_train_examples)
logger.info(" Num Epochs = %d", epochs)
logger.info(" Total optimization steps = %d", self.train_steps)
for epoch in range(start_epoch, int(epochs + 1)):
for training_loss in self._training_steps():
for training_loss in self._training_steps(train_ds, optimizer):
step = iterations.numpy()
if self.args.debug:
with self.writer.as_default():
with self.tb_writer.as_default():
tf.summary.scalar("loss", training_loss, step=step)
if step == 1 and self.args.debug:
with self.writer.as_default():
with self.tb_writer.as_default():
tf.summary.trace_export(name="training", step=step, profiler_outdir=self.args.logging_dir)
if self.args.evaluate_during_training and step % self.args.eval_steps == 0:
@@ -293,17 +284,16 @@ class TFTrainer:
eval_key = "eval_{}".format(key)
logs[eval_key] = value
if callable(self.optimizer.learning_rate):
logs["learning_rate"] = self.optimizer.learning_rate(step).numpy()
else:
logs["learning_rate"] = self.optimizer.learning_rate.numpy()
logs["learning_rate"] = lr_scheduler(step).numpy()
logger.info("Epoch {} Step {} Validation Metrics {}".format(epoch, step, logs))
with self.writer.as_default():
with self.tb_writer.as_default():
for k, v in logs.items():
tf.summary.scalar(k, v, step=step)
self.tb_writer.flush()
if step % self.args.logging_steps == 0:
logger.info("Epoch {} Step {} Train Loss {:.4f}".format(epoch, step, training_loss.numpy()))
@@ -314,21 +304,21 @@ class TFTrainer:
if step % self.train_steps == 0:
break
def _training_steps(self):
def _training_steps(self, ds, optimizer):
"""
Returns a generator over training steps (i.e. parameters update).
"""
for i, loss in enumerate(self._accumulate_next_gradients()):
for i, loss in enumerate(self._accumulate_next_gradients(ds)):
if i % self.args.gradient_accumulation_steps == 0:
self._apply_gradients()
self._apply_gradients(optimizer)
yield loss
@tf.function
def _apply_gradients(self):
def _apply_gradients(self, optimizer):
"""Applies the gradients (cross-replica)."""
self.args.strategy.experimental_run_v2(self._step)
self.args.strategy.experimental_run_v2(self._step, args=(optimizer,))
def _step(self):
def _step(self, optimizer):
"""Applies gradients and resets accumulation."""
gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
gradients = [
@@ -336,12 +326,12 @@ class TFTrainer:
]
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
self.gradient_accumulator.reset()
def _accumulate_next_gradients(self):
def _accumulate_next_gradients(self, ds):
"""Accumulates the gradients from the next element in dataset."""
iterator = iter(self.train_dataset)
iterator = iter(ds)
@tf.function
def _accumulate_next():
@@ -388,23 +378,10 @@ class TFTrainer:
labels: the batched labels.
training: run the model in training mode or not
"""
if self.args.mode == "text-classification" or self.args.mode == "token-classification":
logits = self.model(features, training=training)[0]
if isinstance(labels, (dict)):
loss, logits = self.model(features, training=training, **labels)[:2]
else:
logits = self.model(features, training=training)
if self.args.mode == "token-classification":
active_loss = tf.reshape(labels, (-1,)) != -1
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
loss = self.loss(labels, reduced_logits)
elif self.args.mode == "question-answering":
start_loss = self.loss(labels["start_position"], logits[0])
end_loss = self.loss(labels["end_position"], logits[1])
loss = (start_loss + end_loss) / 2.0
else:
loss = self.loss(labels, logits)
loss, logits = self.model(features, labels=labels, training=training)[:2]
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
return loss, logits
@@ -418,19 +395,24 @@ class TFTrainer:
test_dataset: something similar to a PT Dataset. This is just a
temporary workaround until a framework-agnostic approach for datasets is available.
"""
test_dataset = test_dataset.batch(self.args.eval_batch_size)
test_dataset = self.args.strategy.experimental_distribute_dataset(test_dataset)
test_ds = self.get_test_tfdataset(test_dataset)
return self._prediction_loop(test_dataset, description="Prediction")
return self._prediction_loop(test_ds, description="Prediction")
def save_model(self) -> None:
def save_model(self, output_dir: Optional[str] = None):
"""
Save the pretrained model and create a Tensorflow saved model.
"""
logger.info("Saving model in {}".format(self.args.output_dir))
output_dir = output_dir if output_dir is not None else self.args.output_dir
logger.info("Saving model in {}".format(output_dir))
path = os.path.join(output_dir, "saved_model")
logger.info("Saving model in {}".format(path))
os.makedirs(path, exist_ok=True)
if not isinstance(self.model, TFPreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self.model.save_pretrained(output_dir)
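Taken together, the TF trainer now mirrors the PyTorch one: datasets are wrapped on demand by the get_*_tfdataset methods, and the optimizer/schedule pair and TensorBoard writer are plain constructor arguments (or come from overriding get_optimizers in a subclass). A hedged end-to-end sketch with a toy in-memory dataset; the checkpoint and hyper-parameters are illustrative, not part of the diff:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, TFTrainer, TFTrainingArguments
from transformers.optimization_tf import create_optimizer
args = TFTrainingArguments(output_dir="./out", num_train_epochs=2, weight_decay=0.01)
# Toy in-memory dataset of (features, labels); real use would tokenize a corpus.
train_dataset = tf.data.Dataset.from_tensor_slices((
    {
        "input_ids": tf.constant([[101, 7592, 2088, 102]] * 8, dtype=tf.int32),
        "attention_mask": tf.constant([[1, 1, 1, 1]] * 8, dtype=tf.int32),
    },
    tf.constant([1, 0, 1, 0, 1, 0, 1, 0], dtype=tf.int32),
))
with args.strategy.scope():
    model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
    # The trainer uses this pair as-is instead of building its own optimizer.
    optimizers = create_optimizer(
        init_lr=args.learning_rate, num_train_steps=100, num_warmup_steps=10, weight_decay_rate=args.weight_decay
    )
trainer = TFTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    tb_writer=tf.summary.create_file_writer(args.logging_dir),
    optimizers=optimizers,  # omitting this falls back to get_optimizers()
)
# trainer.train()  # runs the distributed training loop with the supplied optimizer/schedule pair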

View File

@@ -1,6 +1,7 @@
import dataclasses
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
@@ -27,6 +28,17 @@ def is_tpu_available():
logger = logging.getLogger(__name__)
def default_logdir() -> str:
"""
Same default as PyTorch
"""
import socket
from datetime import datetime
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
return os.path.join("runs", current_time + "_" + socket.gethostname())
@dataclass
class TrainingArguments:
"""
@@ -97,7 +109,7 @@ class TrainingArguments:
)
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
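With default_logdir as the default_factory, logging_dir now resolves to the same runs/<timestamp>_<hostname> layout as the PyTorch trainer instead of None. A quick check of the effect (the printed path is only an example):
from transformers import TFTrainingArguments  # inherits the new logging_dir default
args = TFTrainingArguments(output_dir="./out")
print(args.logging_dir)  # e.g. runs/Jun05_01-45-53_<hostname>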

View File

@@ -14,28 +14,9 @@ if is_tf_available():
@dataclass
class TFTrainingArguments(TrainingArguments):
optimizer_name: str = field(
default="adam",
metadata={
"help": 'Name of a Tensorflow optimizer among "adadelta, adagrad, adam, adamax, ftrl, nadam, rmsprop, sgd, adamw"'
},
)
mode: str = field(
default="text-classification",
metadata={"help": 'Type of task, one of "text-classification", "token-classification", "question-answering"'},
)
loss_name: str = field(
default="SparseCategoricalCrossentropy",
metadata={
"help": "Name of a Tensorflow loss. For the list see: https://www.tensorflow.org/api_docs/python/tf/keras/losses"
},
)
tpu_name: str = field(
default=None, metadata={"help": "Name of TPU"},
)
end_lr: float = field(
default=0, metadata={"help": "End learning rate for optimizer"},
)
eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
debug: bool = field(
default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}

View File

@@ -30,7 +30,7 @@ if is_tf_available():
import tensorflow as tf
import numpy as np
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding, TFSharedEmbeddings
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
@@ -107,16 +107,32 @@ class TFModelTesterMixin:
and getattr(module_member, "_keras_serializable", False)
)
for main_layer_class in tf_main_layer_classes:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if "T5" in main_layer_class.__name__:
# Take the same values as in TFT5ModelTester for this shared layer
shared = TFSharedEmbeddings(99, 32, name="shared")
main_layer = main_layer_class(config, embed_tokens=shared)
else:
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
}
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
outputs = model(inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "keras_model.h5")
model.save(filepath)
if "T5" in main_layer_class.__name__:
model = tf.keras.models.load_model(
filepath,
custom_objects={
main_layer_class.__name__: main_layer_class,
"TFSharedEmbeddings": TFSharedEmbeddings,
},
)
else:
model = tf.keras.models.load_model(
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
)
@@ -126,6 +142,9 @@ class TFModelTesterMixin:
def assert_outputs_same(self, after_outputs, outputs):
# Make sure we don't have nans
if isinstance(after_outputs, tf.Tensor):
out_1 = after_outputs.numpy()
else:
out_1 = after_outputs[0].numpy()
out_2 = outputs[0].numpy()
self.assertEqual(out_1.shape, out_2.shape)
@@ -269,7 +288,6 @@ class TFModelTesterMixin:
inputs_keywords = copy.deepcopy(inputs_dict)
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "inputs", None,)
outputs_keywords = model(input_ids, **inputs_keywords)
output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy()

View File

@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from .utils import require_tf, slow
if is_tf_available():
import tensorflow as tf
import numpy as np
from transformers import TFFlaubertModel
@require_tf
class TFFlaubertModelIntegrationTest(unittest.TestCase):
@slow
def test_output_embeds_base_model(self):
model = TFFlaubertModel.from_pretrained("jplu/tf-flaubert-small-cased")
input_ids = tf.convert_to_tensor(
[[0, 158, 735, 2592, 1424, 6727, 82, 1]], dtype=tf.int32,
) # "J'aime flaubert !"
output = model(input_ids)[0]
expected_shape = tf.TensorShape((1, 8, 512))
self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice.
expected_slice = tf.convert_to_tensor(
[
[
[-1.8768773, -1.566555, 0.27072418],
[-1.6920038, -0.5873505, 1.9329599],
[-2.9563985, -1.6993835, 1.7972052],
]
],
dtype=tf.float32,
)
self.assertTrue(np.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))

View File

@@ -0,0 +1,55 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from .utils import require_tf, slow
if is_tf_available():
import tensorflow as tf
import numpy as np
from transformers import TFXLMRobertaModel
@require_tf
class TFXLMRobertaModelIntegrationTest(unittest.TestCase):
@slow
def test_output_embeds_base_model(self):
model = TFXLMRobertaModel.from_pretrained("jplu/tf-xlm-roberta-base")
features = {
"input_ids": tf.convert_to_tensor([[0, 2646, 10269, 83, 99942, 2]], dtype=tf.int32), # "My dog is cute"
"attention_mask": tf.convert_to_tensor([[1, 1, 1, 1, 1, 1]], dtype=tf.int32),
}
output = model(features)[0]
expected_shape = tf.TensorShape((1, 6, 768))
self.assertEqual(output.shape, expected_shape)
# compare the actual values for a slice.
expected_slice = tf.convert_to_tensor(
[
[
[0.0681762, 0.10894451, 0.06772504],
[-0.06423668, 0.02366615, 0.04329344],
[-0.06057295, 0.09974135, -0.00070584],
]
],
dtype=tf.float32,
)
self.assertTrue(np.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))

View File

@@ -47,7 +47,7 @@ class OptimizationFTest(unittest.TestCase):
with strategy.scope():
accumulator = GradientAccumulator()
variable = tf.Variable([4.0, 3.0])
optimizer = create_optimizer(5e-5, 10, 5)
optimizer, _ = create_optimizer(5e-5, 10, 5)
gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)
def accumulate_on_replica(gradient):