Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Tensorflow improvements (#4530)
* Better None gradients handling
* Apply Style
* Apply Style
* Create a loss class per task to compute its respective loss
* Add loss classes to the ALBERT TF models
* Add loss classes to the BERT TF models
* Add question answering and multiple choice to TF Camembert
* Remove prints
* Add multiple choice model to TF DistilBERT + loss computation
* Add question answering model to TF Electra + loss computation
* Add token classification, question answering and multiple choice models to TF Flaubert
* Add multiple choice model to TF Roberta + loss computation
* Add multiple choice model to TF XLM + loss computation
* Add multiple choice and question answering models to TF XLM-Roberta
* Add multiple choice model to TF XLNet + loss computation
* Remove unused parameters
* Add task loss classes
* Reorder TF imports + add new model classes
* Add new model classes
* Bugfix in TF T5 model
* Bugfix for TF T5 tests
* Bugfix in TF T5 model
* Fix TF T5 model tests
* Fix T5 tests + some renaming
* Fix inheritance issue in the AutoX tests
* Add tests for TF Flaubert and TF XLM-Roberta
* Add tests for TF Flaubert and TF XLM-Roberta
* Remove unused piece of code in the TF trainer
* Bugfix and remove unused code
* Bugfix for TF 2.2
* Apply Style
* Split TFSequenceClassificationAndMultipleChoiceLoss into its two respective classes
* Apply style
* Mirror the PT Trainer in the TF one: fp16, optimizers and tb_writer as class parameters, and better dataset handling
* Fix TF optimization tests and apply style
* Remove useless parameter
* Bugfix and apply style
* Fix TF Trainer prediction
* Now the TF models return the loss like their PyTorch counterparts
* Apply Style
* Ignore some tests output
* Take into account the SQuAD cls_index, p_mask and is_impossible parameters for the QuestionAnswering task models
* Fix names for SQuAD data
* Apply Style
* Fix conflicts with 2.11 release
* Fix conflicts with 2.11
* Fix wrong name
* Add better documentation on the new create_optimizer function
* Fix isort
* logging_dir: use same default as PyTorch

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:

parent: ccd26c2862
commit: f9414f7553
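
The change that cuts across all of the model files below: the task-specific TF models now accept a labels argument (or start_positions/end_positions for question answering) and, when labels are given, prepend the computed loss to their output tuple, mirroring their PyTorch counterparts. A minimal sketch of the new calling convention, taken from the updated docstrings in this diff (checkpoint name as used there):

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForSequenceClassification

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # batch size 1
    labels = tf.reshape(tf.constant(1), (-1, 1))                                 # batch size 1

    outputs = model(input_ids, labels=labels)   # with labels: (loss, logits, ...)
    loss, logits = outputs[:2]

    outputs = model(input_ids)                  # without labels: (logits, ...)
    logits = outputs[0]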
Changed file: .gitignore (vendored), 4 additions

@@ -8,6 +8,10 @@ __pycache__/
 # C extensions
 *.so
 
+# tests and logs
+tests/fixtures
+logs/
+
 # Distribution / packaging
 .Python
 build/
src/transformers/__init__.py

@@ -352,173 +352,193 @@ if is_torch_available():

The TensorFlow import block is re-sorted alphabetically and gains the new task-specific model classes and auto mappings (multiple choice, question answering and token classification for the models listed in the commit message). The block after the change:

# TensorFlow
if is_tf_available():
    from .modeling_tf_utils import (
        shape_list,
        tf_top_k_top_p_filtering,
        TFPreTrainedModel,
        TFSequenceSummary,
        TFSharedEmbeddings,
    )
    from .modeling_tf_auto import (
        TF_MODEL_MAPPING,
        TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        TF_MODEL_WITH_LM_HEAD_MAPPING,
        TFAutoModel,
        TFAutoModelForMultipleChoice,
        TFAutoModelForPreTraining,
        TFAutoModelForQuestionAnswering,
        TFAutoModelForSequenceClassification,
        TFAutoModelForTokenClassification,
        TFAutoModelWithLMHead,
    )
    from .modeling_tf_albert import (
        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFAlbertForMaskedLM,
        TFAlbertForMultipleChoice,
        TFAlbertForPreTraining,
        TFAlbertForQuestionAnswering,
        TFAlbertForSequenceClassification,
        TFAlbertForTokenClassification,
        TFAlbertMainLayer,
        TFAlbertModel,
        TFAlbertPreTrainedModel,
    )
    from .modeling_tf_bert import (
        TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFBertEmbeddings,
        TFBertForMaskedLM,
        TFBertForMultipleChoice,
        TFBertForNextSentencePrediction,
        TFBertForPreTraining,
        TFBertForQuestionAnswering,
        TFBertForSequenceClassification,
        TFBertForTokenClassification,
        TFBertMainLayer,
        TFBertModel,
        TFBertPreTrainedModel,
    )
    from .modeling_tf_camembert import (
        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCamembertForMaskedLM,
        TFCamembertModel,
        TFCamembertForMultipleChoice,
        TFCamembertForQuestionAnswering,
        TFCamembertForSequenceClassification,
        TFCamembertForTokenClassification,
    )
    from .modeling_tf_ctrl import (
        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCTRLLMHeadModel,
        TFCTRLModel,
        TFCTRLPreTrainedModel,
    )
    from .modeling_tf_distilbert import (
        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFDistilBertForMaskedLM,
        TFDistilBertForMultipleChoice,
        TFDistilBertForQuestionAnswering,
        TFDistilBertForSequenceClassification,
        TFDistilBertForTokenClassification,
        TFDistilBertMainLayer,
        TFDistilBertModel,
        TFDistilBertPreTrainedModel,
    )
    from .modeling_tf_electra import (
        TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFElectraForMaskedLM,
        TFElectraForPreTraining,
        TFElectraForQuestionAnswering,
        TFElectraForTokenClassification,
        TFElectraModel,
        TFElectraPreTrainedModel,
    )
    from .modeling_tf_flaubert import (
        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFFlaubertForMultipleChoice,
        TFFlaubertForQuestionAnsweringSimple,
        TFFlaubertForSequenceClassification,
        TFFlaubertForTokenClassification,
        TFFlaubertWithLMHeadModel,
        TFFlaubertModel,
    )
    from .modeling_tf_gpt2 import (
        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFGPT2DoubleHeadsModel,
        TFGPT2LMHeadModel,
        TFGPT2MainLayer,
        TFGPT2Model,
        TFGPT2PreTrainedModel,
    )
    from .modeling_tf_openai import (
        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFOpenAIGPTDoubleHeadsModel,
        TFOpenAIGPTLMHeadModel,
        TFOpenAIGPTMainLayer,
        TFOpenAIGPTModel,
        TFOpenAIGPTPreTrainedModel,
    )
    from .modeling_tf_roberta import (
        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFRobertaForMaskedLM,
        TFRobertaForMultipleChoice,
        TFRobertaForQuestionAnswering,
        TFRobertaForSequenceClassification,
        TFRobertaForTokenClassification,
        TFRobertaMainLayer,
        TFRobertaModel,
        TFRobertaPreTrainedModel,
    )
    from .modeling_tf_t5 import (
        TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFT5ForConditionalGeneration,
        TFT5Model,
        TFT5PreTrainedModel,
    )
    from .modeling_tf_transfo_xl import (
        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFAdaptiveEmbedding,
        TFTransfoXLLMHeadModel,
        TFTransfoXLMainLayer,
        TFTransfoXLModel,
        TFTransfoXLPreTrainedModel,
    )
    from .modeling_tf_xlm import (
        TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFXLMForMultipleChoice,
        TFXLMForQuestionAnsweringSimple,
        TFXLMForSequenceClassification,
        TFXLMForTokenClassification,
        TFXLMWithLMHeadModel,
        TFXLMMainLayer,
        TFXLMModel,
        TFXLMPreTrainedModel,
    )
    from .modeling_tf_xlm_roberta import (
        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFXLMRobertaForMaskedLM,
        TFXLMRobertaForMultipleChoice,
        TFXLMRobertaForQuestionAnswering,
        TFXLMRobertaForSequenceClassification,
        TFXLMRobertaForTokenClassification,
        TFXLMRobertaModel,
    )
    from .modeling_tf_xlnet import (
        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFXLNetForMultipleChoice,
        TFXLNetForQuestionAnsweringSimple,
        TFXLNetForSequenceClassification,
        TFXLNetForTokenClassification,
        TFXLNetLMHeadModel,
        TFXLNetMainLayer,
        TFXLNetModel,
        TFXLNetPreTrainedModel,
    )

    # Optimization
    from .optimization_tf import (
        AdamWeightDecay,
        create_optimizer,
        GradientAccumulator,
        WarmUp,
    )

    # Trainer
    from .trainer_tf import TFTrainer
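
The newly exported auto classes dispatch on the extended per-task mappings added in modeling_tf_auto.py further down. A hedged usage sketch (the checkpoint name is illustrative):

    from transformers import AutoTokenizer, TFAutoModelForTokenClassification

    # Resolved through TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING based on the checkpoint's config class.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = TFAutoModelForTokenClassification.from_pretrained("bert-base-cased")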
src/transformers/data/processors/squad.py

@@ -394,8 +394,8 @@ def squad_convert_examples_to_features(
                     "qas_id": ex.qas_id,
                 },
                 {
-                    "start_position": ex.start_position,
-                    "end_position": ex.end_position,
+                    "start_positions": ex.start_position,
+                    "end_positions": ex.end_position,
                     "cls_index": ex.cls_index,
                     "p_mask": ex.p_mask,
                     "is_impossible": ex.is_impossible,

@@ -412,8 +412,8 @@ def squad_convert_examples_to_features(
                    "qas_id": tf.string,
                },
                {
-                   "start_position": tf.int64,
-                   "end_position": tf.int64,
+                   "start_positions": tf.int64,
+                   "end_positions": tf.int64,
                    "cls_index": tf.int64,
                    "p_mask": tf.int32,
                    "is_impossible": tf.int32,

@@ -429,8 +429,8 @@ def squad_convert_examples_to_features(
                    "qas_id": tf.TensorShape([]),
                },
                {
-                   "start_position": tf.TensorShape([]),
-                   "end_position": tf.TensorShape([]),
+                   "start_positions": tf.TensorShape([]),
+                   "end_positions": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                    "is_impossible": tf.TensorShape([]),
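
After this rename, the label dict of the TF dataset uses the same names as the question-answering models' new call arguments (start_positions, end_positions), so the dataset can be fed to them directly. A sketch of how the renamed keys surface, assuming the usual return_dataset="tf" call (argument values are illustrative):

    # `examples` comes from SquadV1Processor/SquadV2Processor, `tokenizer` is a matching tokenizer.
    dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=True,
        return_dataset="tf",
    )
    for features, labels in dataset.take(1):
        print(labels["start_positions"], labels["end_positions"])  # renamed from start_position/end_position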
src/transformers/hf_argparser.py

@@ -81,6 +81,8 @@ class HfArgumentParser(ArgumentParser):
             kwargs["type"] = field.type
             if field.default is not dataclasses.MISSING:
                 kwargs["default"] = field.default
+            elif field.default_factory is not dataclasses.MISSING:
+                kwargs["default"] = field.default_factory()
             else:
                 kwargs["required"] = True
             self.add_argument(field_name, **kwargs)
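
This lets dataclass fields declared with a default_factory (the idiomatic way to give mutable or computed defaults) become optional command-line arguments instead of being forced to required. A minimal sketch with a hypothetical arguments dataclass:

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser

    @dataclass
    class ExampleArguments:
        # Before this change, a default_factory field fell through to kwargs["required"] = True.
        run_name: str = field(default_factory=lambda: "default-run")
        logging_dir: str = field(default="runs")

    parser = HfArgumentParser(ExampleArguments)
    (args,) = parser.parse_args_into_dataclasses(args=["--logging_dir", "logs"])
    print(args.run_name, args.logging_dir)  # default-run logs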
src/transformers/modeling_tf_albert.py

@@ -23,7 +23,16 @@ import tensorflow as tf
 from .configuration_albert import AlbertConfig
 from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import ACT2FN, TFBertSelfAttention
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
+from .modeling_tf_utils import (
+    TFMultipleChoiceLoss,
+    TFPreTrainedModel,
+    TFQuestionAnsweringLoss,
+    TFSequenceClassificationLoss,
+    TFTokenClassificationLoss,
+    get_initializer,
+    keras_serializable,
+    shape_list,
+)
 from .tokenization_utils import BatchEncoding

@@ -841,7 +850,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
     the pooled output) e.g. for GLUE tasks. """,
     ALBERT_START_DOCSTRING,
 )
-class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
+class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels

@@ -852,9 +861,25 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
-    def call(self, inputs, **kwargs):
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
         r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
+            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
+
         Returns:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
         logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`)

@@ -878,27 +903,126 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
         tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
         model = TFAlbertForSequenceClassification.from_pretrained('albert-base-v2')
         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        logits = outputs[0]
+        labels = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, logits = outputs[:2]
 
         """
-        outputs = self.albert(inputs, **kwargs)
+
+        outputs = self.albert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
 
         pooled_output = outputs[1]
 
-        pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
+        pooled_output = self.dropout(pooled_output, training=training)
         logits = self.classifier(pooled_output)
 
         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
 
-        return outputs  # logits, (hidden_states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Albert Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    ALBERT_START_DOCSTRING,
+)
+class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.albert = TFAlbertMainLayer(config, name="albert")
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+        Return:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            import tensorflow as tf
+            from transformers import AlbertTokenizer, TFAlbertForTokenClassification
+
+            tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+            model = TFAlbertForTokenClassification.from_pretrained('albert-base-v2')
+            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
+            labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
+            outputs = model(input_ids, labels=labels)
+            loss, scores = outputs[:2]
+
+        """
+        outputs = self.albert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
+
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output, training=training)
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
 
 
 @add_start_docstrings(
     """Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
     ALBERT_START_DOCSTRING,
 )
-class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
+class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels

@@ -908,9 +1032,32 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
         )
 
-    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
-    def call(self, inputs, **kwargs):
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
         r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
         Return:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
         start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):

@@ -938,14 +1085,23 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
 
         tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
         model = TFAlbertForQuestionAnswering.from_pretrained('albert-base-v2')
-        input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
-        start_scores, end_scores = model(tf.constant(input_ids)[None, :])  # Batch size 1
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
+        start_scores, end_scores = model(input_dict)
 
-        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
 
         """
-        outputs = self.albert(inputs, **kwargs)
+        outputs = self.albert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
 
         sequence_output = outputs[0]
 

@@ -956,7 +1112,13 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
 
         outputs = (start_logits, end_logits,) + outputs[2:]
 
-        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
 
 
 @add_start_docstrings(

@@ -964,7 +1126,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
     ALBERT_START_DOCSTRING,
 )
-class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
+class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
 

@@ -992,9 +1154,15 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        labels=None,
         training=False,
     ):
         r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
         Return:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
         classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:

@@ -1019,12 +1187,13 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
 
         tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
         model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
 
-        example1 = ["This is a context", "Is it a context? Yes"]
-        example2 = ["This is a context", "Is it a context? No"]
-        encoding = tokenizer.batch_encode_plus([example1, example2], return_tensors='tf', truncation_strategy="only_first", pad_to_max_length=True, max_length=128)
-        outputs = model(encoding["input_ids"][None, :])
-        logits = outputs[0]
+        input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+        labels = tf.reshape(tf.constant(1), (-1, 1))
+        outputs = model(input_ids, labels=labels)
+
+        loss, classification_scores = outputs[:2]
 
         """
         if isinstance(inputs, (tuple, list)):

@@ -1036,10 +1205,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
             inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
             assert len(inputs) <= 6, "Too many inputs."
         elif isinstance(inputs, dict):
-            print("isdict(1)")
             input_ids = inputs.get("input_ids")
-            print(input_ids)
-
             attention_mask = inputs.get("attention_mask", attention_mask)
             token_type_ids = inputs.get("token_type_ids", token_type_ids)
             position_ids = inputs.get("position_ids", position_ids)

@@ -1080,4 +1246,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
 
         outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
 
-        return outputs  # reshaped_logits, (hidden_states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
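
Because the models now compute and return their own loss, a hand-written training step only has to pass the labels through; no external loss function is needed. A hedged sketch of one gradient step with the updated ALBERT head (checkpoint, optimizer and toy batch are illustrative):

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertForSequenceClassification

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = TFAlbertForSequenceClassification.from_pretrained("albert-base-v2")
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # batch size 1
    labels = tf.reshape(tf.constant(1), (-1, 1))

    with tf.GradientTape() as tape:
        outputs = model(input_ids, labels=labels, training=True)  # (loss, logits, ...)
        loss = tf.reduce_mean(outputs[0])
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))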
src/transformers/modeling_tf_auto.py

@@ -22,14 +22,18 @@ from .configuration_auto import (
     AlbertConfig,
     AutoConfig,
     BertConfig,
+    CamembertConfig,
     CTRLConfig,
     DistilBertConfig,
+    ElectraConfig,
+    FlaubertConfig,
     GPT2Config,
     OpenAIGPTConfig,
     RobertaConfig,
     T5Config,
     TransfoXLConfig,
     XLMConfig,
+    XLMRobertaConfig,
     XLNetConfig,
 )
 from .configuration_utils import PretrainedConfig

@@ -39,6 +43,7 @@ from .modeling_tf_albert import (
     TFAlbertForPreTraining,
     TFAlbertForQuestionAnswering,
     TFAlbertForSequenceClassification,
+    TFAlbertForTokenClassification,
     TFAlbertModel,
 )
 from .modeling_tf_bert import (

@@ -50,18 +55,43 @@ from .modeling_tf_bert import (
     TFBertForTokenClassification,
     TFBertModel,
 )
+from .modeling_tf_camembert import (
+    TFCamembertForMaskedLM,
+    TFCamembertForMultipleChoice,
+    TFCamembertForQuestionAnswering,
+    TFCamembertForSequenceClassification,
+    TFCamembertForTokenClassification,
+    TFCamembertModel,
+)
 from .modeling_tf_ctrl import TFCTRLLMHeadModel, TFCTRLModel
 from .modeling_tf_distilbert import (
     TFDistilBertForMaskedLM,
+    TFDistilBertForMultipleChoice,
     TFDistilBertForQuestionAnswering,
     TFDistilBertForSequenceClassification,
     TFDistilBertForTokenClassification,
     TFDistilBertModel,
 )
+from .modeling_tf_electra import (
+    TFElectraForMaskedLM,
+    TFElectraForPreTraining,
+    TFElectraForQuestionAnswering,
+    TFElectraForTokenClassification,
+    TFElectraModel,
+)
+from .modeling_tf_flaubert import (
+    TFFlaubertForMultipleChoice,
+    TFFlaubertForQuestionAnsweringSimple,
+    TFFlaubertForSequenceClassification,
+    TFFlaubertForTokenClassification,
+    TFFlaubertModel,
+    TFFlaubertWithLMHeadModel,
+)
 from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
 from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
 from .modeling_tf_roberta import (
     TFRobertaForMaskedLM,
+    TFRobertaForMultipleChoice,
     TFRobertaForQuestionAnswering,
     TFRobertaForSequenceClassification,
     TFRobertaForTokenClassification,

@@ -70,12 +100,23 @@ from .modeling_tf_roberta import (
 from .modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
 from .modeling_tf_transfo_xl import TFTransfoXLLMHeadModel, TFTransfoXLModel
 from .modeling_tf_xlm import (
+    TFXLMForMultipleChoice,
     TFXLMForQuestionAnsweringSimple,
     TFXLMForSequenceClassification,
+    TFXLMForTokenClassification,
     TFXLMModel,
     TFXLMWithLMHeadModel,
 )
+from .modeling_tf_xlm_roberta import (
+    TFXLMRobertaForMaskedLM,
+    TFXLMRobertaForMultipleChoice,
+    TFXLMRobertaForQuestionAnswering,
+    TFXLMRobertaForSequenceClassification,
+    TFXLMRobertaForTokenClassification,
+    TFXLMRobertaModel,
+)
 from .modeling_tf_xlnet import (
+    TFXLNetForMultipleChoice,
     TFXLNetForQuestionAnsweringSimple,
     TFXLNetForSequenceClassification,
     TFXLNetForTokenClassification,

@@ -89,83 +130,118 @@ logger = logging.getLogger(__name__)

The auto mappings are re-sorted alphabetically and extended: Camembert, Electra, Flaubert and XLM-RoBERTa entries are added where the corresponding heads exist, and the multiple choice, question answering, sequence classification and token classification mappings now cover all supported TF models instead of a handful. The mappings after the change:

TF_MODEL_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertModel),
        (BertConfig, TFBertModel),
        (CamembertConfig, TFCamembertModel),
        (CTRLConfig, TFCTRLModel),
        (DistilBertConfig, TFDistilBertModel),
        (ElectraConfig, TFElectraModel),
        (FlaubertConfig, TFFlaubertModel),
        (GPT2Config, TFGPT2Model),
        (OpenAIGPTConfig, TFOpenAIGPTModel),
        (RobertaConfig, TFRobertaModel),
        (T5Config, TFT5Model),
        (TransfoXLConfig, TFTransfoXLModel),
        (XLMConfig, TFXLMModel),
        (XLMRobertaConfig, TFXLMRobertaModel),
        (XLNetConfig, TFXLNetModel),
    ]
)

TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertForPreTraining),
        (BertConfig, TFBertForPreTraining),
        (CamembertConfig, TFCamembertForMaskedLM),
        (CTRLConfig, TFCTRLLMHeadModel),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (ElectraConfig, TFElectraForPreTraining),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (RobertaConfig, TFRobertaForMaskedLM),
        (T5Config, TFT5ForConditionalGeneration),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (XLNetConfig, TFXLNetLMHeadModel),
    ]
)

TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (CamembertConfig, TFCamembertForMaskedLM),
        (CTRLConfig, TFCTRLLMHeadModel),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (ElectraConfig, TFElectraForMaskedLM),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (RobertaConfig, TFRobertaForMaskedLM),
        (T5Config, TFT5ForConditionalGeneration),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (XLNetConfig, TFXLNetLMHeadModel),
    ]
)

TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertForMultipleChoice),
        (BertConfig, TFBertForMultipleChoice),
        (CamembertConfig, TFCamembertForMultipleChoice),
        (DistilBertConfig, TFDistilBertForMultipleChoice),
        (FlaubertConfig, TFFlaubertForMultipleChoice),
        (RobertaConfig, TFRobertaForMultipleChoice),
        (XLMConfig, TFXLMForMultipleChoice),
        (XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
        (XLNetConfig, TFXLNetForMultipleChoice),
    ]
)

TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertForQuestionAnswering),
        (BertConfig, TFBertForQuestionAnswering),
        (CamembertConfig, TFCamembertForQuestionAnswering),
        (DistilBertConfig, TFDistilBertForQuestionAnswering),
        (ElectraConfig, TFElectraForQuestionAnswering),
        (FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
        (RobertaConfig, TFRobertaForQuestionAnswering),
        (XLMConfig, TFXLMForQuestionAnsweringSimple),
        (XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
        (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
    ]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertForSequenceClassification),
        (BertConfig, TFBertForSequenceClassification),
        (CamembertConfig, TFCamembertForSequenceClassification),
        (DistilBertConfig, TFDistilBertForSequenceClassification),
        (FlaubertConfig, TFFlaubertForSequenceClassification),
        (RobertaConfig, TFRobertaForSequenceClassification),
        (XLMConfig, TFXLMForSequenceClassification),
        (XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
        (XLNetConfig, TFXLNetForSequenceClassification),
    ]
)

TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        (AlbertConfig, TFAlbertForTokenClassification),
        (BertConfig, TFBertForTokenClassification),
        (CamembertConfig, TFCamembertForTokenClassification),
        (DistilBertConfig, TFDistilBertForTokenClassification),
        (ElectraConfig, TFElectraForTokenClassification),
        (FlaubertConfig, TFFlaubertForTokenClassification),
        (RobertaConfig, TFRobertaForTokenClassification),
        (XLMConfig, TFXLMForTokenClassification),
        (XLMRobertaConfig, TFXLMRobertaForTokenClassification),
        (XLNetConfig, TFXLNetForTokenClassification),
    ]
)

@@ -632,11 +708,13 @@ class TFAutoModelWithLMHead(object):
 
         """
         config = kwargs.pop("config", None)
+
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
 
         for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
-            if isinstance(config, config_class):
+            # Not using isinstance() here to do not take into account inheritance
+            if config_class == type(config):
                 return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
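
The switch from isinstance() to an exact type comparison matters because several configuration classes inherit from others (for instance CamembertConfig and XLMRobertaConfig derive from RobertaConfig), so an isinstance() check could send a subclassed config to the parent's model depending on mapping order. A toy illustration with hypothetical stand-in classes:

    class RobertaLikeConfig:
        pass

    class CamembertLikeConfig(RobertaLikeConfig):  # stands in for a config that subclasses another
        pass

    cfg = CamembertLikeConfig()
    print(isinstance(cfg, RobertaLikeConfig))   # True  -> the parent's entry would also match
    print(type(cfg) == RobertaLikeConfig)       # False -> exact matching skips the parent
    print(type(cfg) == CamembertLikeConfig)     # True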
@ -23,7 +23,16 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
from .modeling_tf_utils import (
|
||||||
|
TFMultipleChoiceLoss,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFQuestionAnsweringLoss,
|
||||||
|
TFSequenceClassificationLoss,
|
||||||
|
TFTokenClassificationLoss,
|
||||||
|
get_initializer,
|
||||||
|
keras_serializable,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
|
|
||||||
@ -880,7 +889,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
|||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFBertForSequenceClassification(TFBertPreTrainedModel):
|
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@ -891,9 +900,25 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
|
|||||||
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
|
||||||
)
|
)
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the sequence classification/regression loss.
|
||||||
|
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
||||||
|
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||||
|
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||||
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
@ -916,21 +941,35 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
|
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
|
||||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||||
outputs = model(input_ids)
|
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
|
||||||
logits = outputs[0]
|
outputs = model(input_ids, labels=labels)
|
||||||
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
outputs = self.bert(inputs, **kwargs)
|
|
||||||
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
|
pooled_output = self.dropout(pooled_output, training=training)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
return outputs # logits, (hidden_states), (attentions)
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
|
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
@@ -938,7 +977,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
-class TFBertForMultipleChoice(TFBertPreTrainedModel):
+class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

@@ -966,9 +1005,15 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
+        labels=None,
        training=False,
    ):
        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
        Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
@@ -993,15 +1038,14 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]

-        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
-        choice0 = "It is eaten with a fork and a knife."
+        labels = tf.reshape(tf.constant(1), (-1, 1))
-        choice1 = "It is eaten while held in the hand."
+        outputs = model(input_ids, labels=labels)
-        encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='tf', pad_to_max_length=True)
+        loss, classification_scores = outputs[:2]

-        # linear classifier on the output is not yet trained
-        outputs = model(encoding['input_ids'][None, :])
-        logits = outputs[0]
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
@@ -1011,7 +1055,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            assert len(inputs) <= 6, "Too many inputs."
-        elif isinstance(inputs, dict):
+        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
@@ -1053,7 +1097,11 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

-        return outputs  # reshaped_logits, (hidden_states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)

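Note: the multiple-choice heads expect `input_ids` of shape `(batch_size, num_choices, sequence_length)`. In the docstring example above the two choices happen to tokenize to the same length; in general the candidates have to be padded before being stacked. An illustrative sketch only (prompt and choices are made up), assuming a version that includes this commit:

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForMultipleChoice

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForMultipleChoice.from_pretrained("bert-base-uncased")

    prompt = "The sky is"
    choices = ["blue.", "made of cheese, or so some people claim."]
    encoding = tokenizer.batch_encode_plus(
        [[prompt, c] for c in choices], return_tensors="tf", pad_to_max_length=True
    )
    # (num_choices, seq_len) -> (1, num_choices, seq_len): a batch of one question.
    input_ids = encoding["input_ids"][None, :]
    attention_mask = encoding["attention_mask"][None, :]
    labels = tf.constant([0])  # index of the correct choice

    loss, classification_scores = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]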
@add_start_docstrings(
@@ -1061,7 +1109,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
-class TFBertForTokenClassification(TFBertPreTrainedModel):
+class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
@@ -1072,9 +1120,23 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
        Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@@ -1098,20 +1160,33 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        outputs = model(input_ids)
+        labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
-        scores = outputs[0]
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]

        """
-        outputs = self.bert(inputs, **kwargs)
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )

        sequence_output = outputs[0]

-        sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
+        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

-        return outputs  # scores, (hidden_states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)

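Note: for token classification the labels tensor is per token, i.e. shape `(batch_size, sequence_length)` aligned with the tokenized input, special tokens included. A simpler way to build the placeholder labels used in the docstring example above, purely illustrative:

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForTokenClassification

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForTokenClassification.from_pretrained("bert-base-uncased")

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
    labels = tf.ones_like(input_ids)  # one placeholder label per token, not a real NER tagging

    loss, scores = model(input_ids, labels=labels)[:2]  # scores: (1, seq_len, num_labels)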
@add_start_docstrings(
@@ -1119,7 +1194,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
-class TFBertForQuestionAnswering(TFBertPreTrainedModel):
+class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
@@ -1129,9 +1204,32 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
        r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
        Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@@ -1156,18 +1254,24 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
-        encoding = tokenizer.encode_plus(question, text)
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
-        input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
+        start_scores, end_scores = model(input_dict)
-        start_scores, end_scores = model(tf.constant(input_ids)[None, :], token_type_ids=tf.constant(token_type_ids)[None, :])

-        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
-        answer = ' '.join(all_tokens[tf.math.argmax(tf.squeeze(start_scores)) : tf.math.argmax(tf.squeeze(end_scores))+1])
+        answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
        assert answer == "a nice puppet"

        """
-        outputs = self.bert(inputs, **kwargs)
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )

        sequence_output = outputs[0]

@@ -1178,4 +1282,10 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):

        outputs = (start_logits, end_logits,) + outputs[2:]

-        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
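Note: the question-answering heads now take `start_positions` and `end_positions` and prepend the loss when both are given; `cls_index`, `p_mask` and `is_impossible` are also accepted for SQuAD-style features. A minimal illustrative sketch, with made-up span indices and assuming a version that includes this commit:

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForQuestionAnswering

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    inputs = tokenizer.encode_plus(question, text, return_tensors="tf")

    # Placeholder gold span indices; in real training these come from the SQuAD features.
    start_positions = tf.constant([10])
    end_positions = tf.constant([12])

    outputs = model(
        inputs["input_ids"],
        token_type_ids=inputs["token_type_ids"],
        start_positions=start_positions,
        end_positions=end_positions,
    )
    loss, start_scores, end_scores = outputs[:3]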
@@ -22,6 +22,8 @@ from .configuration_camembert import CamembertConfig
from .file_utils import add_start_docstrings
from .modeling_tf_roberta import (
    TFRobertaForMaskedLM,
+    TFRobertaForMultipleChoice,
+    TFRobertaForQuestionAnswering,
    TFRobertaForSequenceClassification,
    TFRobertaForTokenClassification,
    TFRobertaModel,
@@ -114,3 +116,30 @@ class TFCamembertForTokenClassification(TFRobertaForTokenClassification):
    """

    config_class = CamembertConfig
+
+
+@add_start_docstrings(
+    """CamemBERT Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    CAMEMBERT_START_DOCSTRING,
+)
+class TFCamembertForMultipleChoice(TFRobertaForMultipleChoice):
+    """
+    This class overrides :class:`~transformers.TFRobertaForMultipleChoice`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+
+    config_class = CamembertConfig
+
+
+@add_start_docstrings(
+    """CamemBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
+    CAMEMBERT_START_DOCSTRING,
+)
+class TFCamembertForQuestionAnswering(TFRobertaForQuestionAnswering):
+    """
+    This class overrides :class:`~transformers.TFRobertaForQuestionAnswering`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+
+    config_class = CamembertConfig
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
|
from .modeling_tf_utils import (
|
||||||
|
TFMultipleChoiceLoss,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFQuestionAnsweringLoss,
|
||||||
|
TFSequenceClassificationLoss,
|
||||||
|
TFSharedEmbeddings,
|
||||||
|
TFTokenClassificationLoss,
|
||||||
|
get_initializer,
|
||||||
|
keras_serializable,
|
||||||
|
shape_list,
|
||||||
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
|
|
||||||
@ -399,7 +409,10 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_serializable
|
||||||
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = DistilBertConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
@ -662,7 +675,7 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
|
|||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
DISTILBERT_START_DOCSTRING,
|
DISTILBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
|
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@ -680,8 +693,16 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
|
|||||||
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
|
self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the sequence classification/regression loss.
|
||||||
|
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||||
|
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||||
|
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
|
||||||
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
@ -705,20 +726,32 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
|
|||||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
||||||
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
|
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-cased')
|
||||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
||||||
outputs = model(input_ids)
|
labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
|
||||||
logits = outputs[0]
|
outputs = model(input_ids, labels=labels)
|
||||||
|
loss, logits = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
distilbert_output = self.distilbert(inputs, **kwargs)
|
distilbert_output = self.distilbert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||||
pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False)) # (bs, dim)
|
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
|
||||||
logits = self.classifier(pooled_output) # (bs, dim)
|
logits = self.classifier(pooled_output) # (bs, dim)
|
||||||
|
|
||||||
outputs = (logits,) + distilbert_output[1:]
|
outputs = (logits,) + distilbert_output[1:]
|
||||||
return outputs # logits, (hidden_states), (attentions)
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
|
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@ -726,7 +759,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
|
|||||||
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
||||||
DISTILBERT_START_DOCSTRING,
|
DISTILBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
|
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@ -738,8 +771,14 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the token classification loss.
|
||||||
|
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
|
||||||
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||||
@ -762,20 +801,154 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
|
|||||||
|
|
||||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
||||||
model = TFDistilBertForTokenClassification.from_pretrained('distilbert-base-cased')
|
model = TFDistilBertForTokenClassification.from_pretrained('distilbert-base-cased')
|
||||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
|
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
|
||||||
outputs = model(input_ids)
|
labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
|
||||||
scores = outputs[0]
|
outputs = model(input_ids, labels=labels)
|
||||||
|
loss, scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
outputs = self.distilbert(inputs, **kwargs)
|
outputs = self.distilbert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
|
|
||||||
sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
|
sequence_output = self.dropout(sequence_output, training=training)
|
||||||
logits = self.classifier(sequence_output)
|
logits = self.classifier(sequence_output)
|
||||||
|
|
||||||
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
return outputs # scores, (hidden_states), (attentions)
|
if labels is not None:
|
||||||
|
loss = self.compute_loss(labels, logits)
|
||||||
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
|
return outputs # (loss), logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
+@add_start_docstrings(
+    """DistilBert Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    DISTILBERT_START_DOCSTRING,
+)
+class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.pre_classifier = tf.keras.layers.Dense(
+            config.dim,
+            kernel_initializer=get_initializer(config.initializer_range),
+            activation="relu",
+            name="pre_classifier",
+        )
+        self.classifier = tf.keras.layers.Dense(
+            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to build the network.
+
+        Returns:
+            tf.Tensor with dummy inputs
+        """
+        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
+
+    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
+    def call(
+        self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+        Return:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
+            `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
+
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            import tensorflow as tf
+            from transformers import DistilBertTokenizer, TFDistilBertForMultipleChoice
+
+            tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+            model = TFDistilBertForMultipleChoice.from_pretrained('distilbert-base-uncased')
+            choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+
+            input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+            labels = tf.reshape(tf.constant(1), (-1, 1))
+            outputs = model(input_ids, labels=labels)
+
+            loss, classification_scores = outputs[:2]
+
+        """
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            head_mask = inputs[2] if len(inputs) > 2 else head_mask
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            assert len(inputs) <= 4, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            input_ids = inputs.get("input_ids")
+            attention_mask = inputs.get("attention_mask", attention_mask)
+            head_mask = inputs.get("head_mask", head_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            assert len(inputs) <= 4, "Too many inputs."
+        else:
+            input_ids = inputs
+
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+
+        flat_inputs = [
+            flat_input_ids,
+            flat_attention_mask,
+            head_mask,
+            inputs_embeds,
+        ]
+
+        distilbert_output = self.distilbert(flat_inputs, training=training)
+        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
+        pooled_output = hidden_state[:, 0]  # (bs, dim)
+        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
+        pooled_output = self.dropout(pooled_output, training=training)  # (bs, dim)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+        outputs = (reshaped_logits,) + distilbert_output[1:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
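Note: the multiple-choice head added above works by folding the choices into the batch dimension, scoring every flattened row with a one-unit classifier, and unfolding the scores again. A shape-only sketch of that trick in plain TensorFlow (standalone illustration, not the model code itself):

    import tensorflow as tf

    batch_size, num_choices, seq_len, dim = 2, 4, 16, 8
    input_ids = tf.zeros((batch_size, num_choices, seq_len), dtype=tf.int32)

    # 1. Flatten the choices into the batch dimension before running the encoder.
    flat_input_ids = tf.reshape(input_ids, (-1, seq_len))      # (8, 16)

    # 2. Stand-in for the pooled encoder output, scored with a 1-unit classifier.
    pooled = tf.random.normal((batch_size * num_choices, dim))
    scores = tf.keras.layers.Dense(1)(pooled)                  # (8, 1)

    # 3. Reshape back so each row holds the scores of one example's choices.
    reshaped_logits = tf.reshape(scores, (-1, num_choices))    # (2, 4)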
@add_start_docstrings(
@@ -783,7 +956,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    DISTILBERT_START_DOCSTRING,
)
-class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
+class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

@@ -795,8 +968,29 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
        self.dropout = tf.keras.layers.Dropout(config.qa_dropout)

    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
        r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
        Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@@ -821,19 +1015,35 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
        model = TFDistilBertForQuestionAnswering.from_pretrained('distilbert-base-cased')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
-        outputs = model(input_ids)
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
-        start_scores, end_scores = outputs[:2]
+        start_scores, end_scores = model(input_dict)
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
+        answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

        """
-        distilbert_output = self.distilbert(inputs, **kwargs)
+        distilbert_output = self.distilbert(
+            input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )

        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
-        hidden_states = self.dropout(hidden_states, training=kwargs.get("training", False))  # (bs, max_query_len, dim)
+        hidden_states = self.dropout(hidden_states, training=training)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + distilbert_output[1:]
-        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
+
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
@@ -6,7 +6,13 @@ from transformers import ElectraConfig

from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
-from .modeling_tf_utils import get_initializer, shape_list
+from .modeling_tf_utils import (
+    TFQuestionAnsweringLoss,
+    TFTokenClassificationLoss,
+    get_initializer,
+    keras_serializable,
+    shape_list,
+)
from .tokenization_utils import BatchEncoding


@@ -194,6 +200,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
        return head_mask


+@keras_serializable
class TFElectraMainLayer(TFElectraPreTrainedModel):

    config_class = ElectraConfig
@@ -557,13 +564,15 @@ Electra model with a token classification head on top.
    Both the discriminator and generator may be loaded into this model.""",
    ELECTRA_START_DOCSTRING,
)
-class TFElectraForTokenClassification(TFElectraPreTrainedModel):
+class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
-        self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
+        self.classifier = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def call(
@@ -574,9 +583,14 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel):
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
+        labels=None,
        training=False,
    ):
        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
        Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@@ -599,9 +613,11 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel):

        tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
        model = TFElectraForTokenClassification.from_pretrained('google/electra-small-discriminator')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        outputs = model(input_ids)
+        labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
-        scores = outputs[0]
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]

        """

        discriminator_hidden_states = self.electra(
@@ -610,7 +626,106 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel):
        discriminator_sequence_output = discriminator_hidden_states[0]
        discriminator_sequence_output = self.dropout(discriminator_sequence_output)
        logits = self.classifier(discriminator_sequence_output)
-        output = (logits,)
-        output += discriminator_hidden_states[1:]

-        return output  # (loss), scores, (hidden_states), (attentions)
+        outputs = (logits,) + discriminator_hidden_states[1:]
+
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.electra = TFElectraMainLayer(config, name="electra")
+        self.qa_outputs = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+
+    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
+        r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
+        Return:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-end scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            import tensorflow as tf
+            from transformers import ElectraTokenizer, TFElectraForQuestionAnswering
+
+            tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-generator')
+            model = TFElectraForQuestionAnswering.from_pretrained('google/electra-small-generator')
+
+            question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+            input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
+            start_scores, end_scores = model(input_dict)
+
+            all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
+            answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
+
+        """
+        discriminator_hidden_states = self.electra(
+            input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, training=training
+        )
+        discriminator_sequence_output = discriminator_hidden_states[0]
+
+        logits = self.qa_outputs(discriminator_sequence_output)
+        start_logits, end_logits = tf.split(logits, 2, axis=-1)
+        start_logits = tf.squeeze(start_logits, axis=-1)
+        end_logits = tf.squeeze(end_logits, axis=-1)
+
+        outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]
+
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
@@ -22,13 +22,16 @@ import tensorflow as tf

from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
+from .modeling_tf_utils import keras_serializable, shape_list
from .modeling_tf_xlm import (
+    TFXLMForMultipleChoice,
+    TFXLMForQuestionAnsweringSimple,
    TFXLMForSequenceClassification,
+    TFXLMForTokenClassification,
    TFXLMMainLayer,
    TFXLMModel,
    TFXLMWithLMHeadModel,
    get_masks,
-    shape_list,
)
from .tokenization_utils import BatchEncoding

@@ -112,6 +115,7 @@ class TFFlaubertModel(TFXLMModel):
        self.transformer = TFFlaubertMainLayer(config, name="transformer")


+@keras_serializable
class TFFlaubertMainLayer(TFXLMMainLayer):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
@@ -327,3 +331,38 @@ class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
+
+
+@add_start_docstrings(
+    """Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
+    the hidden-states output to compute `span start logits` and `span end logits`). """,
+    FLAUBERT_START_DOCSTRING,
+)
+class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
+    config_class = FlaubertConfig
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.transformer = TFFlaubertMainLayer(config, name="transformer")
+
+
+@add_start_docstrings(
+    """Flaubert Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    FLAUBERT_START_DOCSTRING,
+)
+class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.transformer = TFFlaubertMainLayer(config, name="transformer")
+
+
+@add_start_docstrings(
+    """Flaubert Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    FLAUBERT_START_DOCSTRING,
+)
+class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.transformer = TFFlaubertMainLayer(config, name="transformer")
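Note: like the CamemBERT classes, the Flaubert heads added above are thin subclasses of the new XLM task models; they only swap in `TFFlaubertMainLayer` as the `transformer` attribute (and, for question answering, the `config_class`). An illustrative check, assuming the classes are exported at the package top level as the commit's `__init__` changes suggest:

    from transformers import TFFlaubertForMultipleChoice, TFXLMForMultipleChoice

    # All of the multiple-choice logic is inherited from the XLM head.
    assert issubclass(TFFlaubertForMultipleChoice, TFXLMForMultipleChoice)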
@@ -29,6 +29,7 @@ from .modeling_tf_utils import (
    TFSequenceSummary,
    TFSharedEmbeddings,
    get_initializer,
+    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding
@@ -199,7 +200,10 @@ class TFBlock(tf.keras.layers.Layer):
        return outputs  # x, (attentions)


+@keras_serializable
class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
+    config_class = OpenAIGPTConfig
+
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.output_hidden_states = config.output_hidden_states
@@ -21,9 +21,18 @@ import logging
 import tensorflow as tf
 
 from .configuration_roberta import RobertaConfig
-from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
+from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
+from .modeling_tf_utils import (
+    TFMultipleChoiceLoss,
+    TFPreTrainedModel,
+    TFQuestionAnsweringLoss,
+    TFSequenceClassificationLoss,
+    TFTokenClassificationLoss,
+    get_initializer,
+    keras_serializable,
+    shape_list,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -82,6 +91,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
         return super()._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
 
 
+@keras_serializable
 class TFRobertaMainLayer(TFBertMainLayer):
     """
     Same as TFBertMainLayer but uses TFRobertaEmbeddings.
@@ -337,7 +347,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
     on top of the pooled output) e.g. for GLUE tasks. """,
     ROBERTA_START_DOCSTRING,
 )
-class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
+class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels
@@ -346,7 +356,17 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
         self.classifier = TFRobertaClassificationHead(config, name="classifier")
 
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
         r"""
         Return:
             :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
@@ -370,20 +390,164 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
 
         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
         model = TFRobertaForSequenceClassification.from_pretrained('roberta-base')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        labels = tf.constant([1])[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        logits = outputs[0]
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        labels = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, logits = outputs[:2]
 
         """
-        outputs = self.roberta(inputs, **kwargs)
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
+
         sequence_output = outputs[0]
-        logits = self.classifier(sequence_output, training=kwargs.get("training", False))
+        logits = self.classifier(sequence_output, training=training)
 
         outputs = (logits,) + outputs[2:]
 
-        return outputs  # logits, (hidden_states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Roberta Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    ROBERTA_START_DOCSTRING,
+)
+class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.roberta = TFBertMainLayer(config, name="roberta")
+        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.classifier = tf.keras.layers.Dense(
+            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to build the network.
+
+        Returns:
+            tf.Tensor with dummy inputs
+        """
+        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
+
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    def call(
+        self,
+        inputs,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+            classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
+                `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
+
+                Classification scores (before SoftMax).
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            import tensorflow as tf
+            from transformers import RobertaTokenizer, TFRobertaForMultipleChoice
+
+            tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+            model = TFRobertaForMultipleChoice.from_pretrained('roberta-base')
+            choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+
+            input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+            labels = tf.reshape(tf.constant(1), (-1, 1))
+            outputs = model(input_ids, labels=labels)
+
+            loss, classification_scores = outputs[:2]
+
+        """
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            position_ids = inputs[3] if len(inputs) > 3 else position_ids
+            head_mask = inputs[4] if len(inputs) > 4 else head_mask
+            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
+            assert len(inputs) <= 6, "Too many inputs."
+        elif isinstance(inputs, dict):
+            input_ids = inputs.get("input_ids")
+            attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
+            position_ids = inputs.get("position_ids", position_ids)
+            head_mask = inputs.get("head_mask", head_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            assert len(inputs) <= 6, "Too many inputs."
+        else:
+            input_ids = inputs
+
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+
+        flat_inputs = [
+            flat_input_ids,
+            flat_attention_mask,
+            flat_token_type_ids,
+            flat_position_ids,
+            head_mask,
+            inputs_embeds,
+        ]
+
+        outputs = self.roberta(flat_inputs, training=training)
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output, training=training)
+        logits = self.classifier(pooled_output)
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
 
 
 @add_start_docstrings(
@@ -391,7 +555,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel):
     the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
     ROBERTA_START_DOCSTRING,
 )
-class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
+class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels
@@ -403,8 +567,22 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
         )
 
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
         r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
         Return:
             :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
             scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
@@ -428,27 +606,40 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
         model = TFRobertaForTokenClassification.from_pretrained('roberta-base')
         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        scores = outputs[0]
+        labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]
 
         """
-        outputs = self.roberta(inputs, **kwargs)
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
+
         sequence_output = outputs[0]
 
-        sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
+        sequence_output = self.dropout(sequence_output, training=training)
         logits = self.classifier(sequence_output)
 
         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
 
-        return outputs  # scores, (hidden_states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
 
 
 @add_start_docstrings(
     """RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
     ROBERTA_START_DOCSTRING,
 )
-class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
+class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels
@@ -459,8 +650,31 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
         )
 
     @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
         r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
         Return:
             :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
             start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
@@ -488,14 +702,23 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
 
         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
         model = TFRobertaForQuestionAnswering.from_pretrained('roberta-base')
-        input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
-        start_scores, end_scores = model(tf.constant(input_ids)[None, :])  # Batch size 1
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
+        start_scores, end_scores = model(input_dict)
 
-        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
         answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
 
         """
-        outputs = self.roberta(inputs, **kwargs)
+        outputs = self.roberta(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
+
         sequence_output = outputs[0]
 
@@ -506,4 +729,10 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
 
         outputs = (start_logits, end_logits,) + outputs[2:]
 
-        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
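
With the change above, ``TFRobertaForQuestionAnswering`` prepends the averaged start/end loss to its outputs whenever both span labels are passed. A usage sketch, assuming the ``roberta-base`` checkpoint and illustrative gold span indices (a real pipeline would take them from a SQuAD-style dataset)::

    import tensorflow as tf
    from transformers import RobertaTokenizer, TFRobertaForQuestionAnswering

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = TFRobertaForQuestionAnswering.from_pretrained('roberta-base')

    inputs = tokenizer.encode_plus("Who was Jim Henson?", "Jim Henson was a nice puppet", return_tensors='tf')
    start_positions = tf.constant([1])  # illustrative indices, shape (batch_size,)
    end_positions = tf.constant([3])

    outputs = model(inputs["input_ids"], start_positions=start_positions, end_positions=end_positions)
    loss, start_logits, end_logits = outputs[:3]
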
@@ -25,7 +25,8 @@ import tensorflow as tf
 
 from .configuration_t5 import T5Config
 from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
+from .tokenization_utils import BatchEncoding
 
 
 logger = logging.getLogger(__name__)
@@ -502,7 +503,10 @@ class _NoLayerEmbedTokens(object):
 # The full model without a specific pretrained or finetuning head is
 # provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
 ####################################################
+@keras_serializable
 class TFT5MainLayer(tf.keras.layers.Layer):
+    config_class = T5Config
+
     def __init__(self, config, embed_tokens=None, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
@@ -548,12 +552,32 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         use_cache=False,
         training=False,
     ):
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
+            encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
+            inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
+            head_mask = inputs[5] if len(inputs) > 5 else head_mask
+            past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
+            assert len(inputs) <= 7, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            input_ids = inputs.get("decoder_input_ids")
+            attention_mask = inputs.get("decoder_attention_mask", attention_mask)
+            encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
+            encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            head_mask = inputs.get("head_mask", head_mask)
+            past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
+            assert len(inputs) <= 7, "Too many inputs."
+        else:
+            input_ids = inputs
+
-        if inputs is not None and inputs_embeds is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
-        elif inputs is not None:
-            input_shape = shape_list(inputs)
-            inputs = tf.reshape(inputs, (-1, input_shape[-1]))
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
         elif inputs_embeds is not None:
             input_shape = shape_list(inputs_embeds)[:-1]
         else:
@@ -561,7 +585,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
 
         if inputs_embeds is None:
             assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings"
-            inputs_embeds = self.embed_tokens(inputs)
+            inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_shape
 
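
The ``TFT5MainLayer`` change above lets the decoder stack be fed positionally, as a list, or as a dict keyed with ``decoder_input_ids`` / ``decoder_attention_mask``. The dispatch pattern in isolation, as a standalone sketch (the helper name is made up; only the key names come from the code above, and the argument list is shortened to three)::

    def parse_decoder_inputs(inputs, attention_mask=None, encoder_hidden_states=None):
        # Mirrors the list/dict handling added to TFT5MainLayer.call above.
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
            assert len(inputs) <= 3, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("decoder_input_ids")
            attention_mask = inputs.get("decoder_attention_mask", attention_mask)
            encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
        else:
            input_ids = inputs
        return input_ids, attention_mask, encoder_hidden_states
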
@@ -734,7 +734,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
         return outputs
 
 
-class TFTransfoXLLMHead(tf.keras.layers.Layer):
+class TFTransfoXLMHead(tf.keras.layers.Layer):
     def __init__(self, config, input_embeddings, **kwargs):
         super().__init__(**kwargs)
         self.vocab_size = config.vocab_size
@@ -84,6 +84,7 @@ def keras_serializable(cls):
         else:
             raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
         self._transformers_config = config
+        self._kwargs = kwargs
 
     cls.__init__ = wrapped_init
 
@@ -94,6 +95,7 @@ def keras_serializable(cls):
     def get_config(self):
         cfg = super(cls, self).get_config()
         cfg["transformers_config"] = self._transformers_config.to_dict()
+        cfg.update(self._kwargs)
         return cfg
 
     cls.get_config = get_config
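
The extra ``self._kwargs`` bookkeeping exists so that ``get_config`` can reproduce not only the ``PretrainedConfig`` but also any plain keyword arguments the layer was built with. A condensed sketch of the idea with simplified names (not the actual decorator)::

    import functools

    def remember_init_kwargs(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def wrapped_init(self, config, *args, **kwargs):
            original_init(self, config, *args, **kwargs)
            self._transformers_config = config
            self._kwargs = kwargs  # kept so get_config() can serialize them too

        cls.__init__ = wrapped_init
        return cls
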
@@ -104,6 +106,44 @@ def keras_serializable(cls):
     return cls
 
 
+class TFQuestionAnsweringLoss:
+    def compute_loss(self, labels, logits):
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+        )
+        start_loss = loss_fn(labels["start_position"], logits[0])
+        end_loss = loss_fn(labels["end_position"], logits[1])
+
+        return (start_loss + end_loss) / 2.0
+
+
+class TFTokenClassificationLoss:
+    def compute_loss(self, labels, logits):
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+        )
+        active_loss = tf.reshape(labels, (-1,)) != -1
+        reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+        labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+
+        return loss_fn(labels, reduced_logits)
+
+
+class TFSequenceClassificationLoss:
+    def compute_loss(self, labels, logits):
+        if shape_list(logits)[1] == 1:
+            loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
+        else:
+            loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+            )
+
+        return loss_fn(labels, logits)
+
+
+TFMultipleChoiceLoss = TFSequenceClassificationLoss
+
+
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
     r""" Base class for all TF models.
 
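
Each task loss is a plain mixin exposing ``compute_loss(labels, logits)``; the task models inherit from one of them next to their pretrained base class and call it when labels are given. A standalone check of the sequence-classification variant with made-up tensors (assuming the mixin stays importable from ``transformers.modeling_tf_utils`` as in this commit)::

    import tensorflow as tf
    from transformers.modeling_tf_utils import TFSequenceClassificationLoss

    class ToyHead(TFSequenceClassificationLoss):
        pass

    logits = tf.constant([[2.0, -1.0], [0.5, 0.3]])  # (batch_size, num_labels); num_labels > 1 -> cross-entropy branch
    labels = tf.constant([0, 1])                     # (batch_size,)
    per_example_loss = ToyHead().compute_loss(labels, logits)
    print(per_example_loss.shape)                    # (2,) because the reduction is NONE
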
@@ -1531,6 +1571,16 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
             )
         super().build(input_shape)
 
+    def get_config(self):
+        config = {
+            "vocab_size": self.vocab_size,
+            "hidden_size": self.hidden_size,
+            "initializer_range": self.initializer_range,
+        }
+        base_config = super().get_config()
+
+        return dict(list(base_config.items()) + list(config.items()))
+
     def call(self, inputs, mode="embedding"):
         """Get token embeddings of inputs.
         Args:
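
With ``get_config`` exposing the three constructor arguments, a ``TFSharedEmbeddings`` layer can be rebuilt through the standard Keras config round-trip. A short sketch, assuming the usual ``TFSharedEmbeddings(vocab_size, hidden_size, initializer_range=None)`` signature (not shown in this hunk)::

    from transformers import TFSharedEmbeddings

    layer = TFSharedEmbeddings(100, 16, name="shared")  # arbitrary sizes
    cfg = layer.get_config()                            # now carries vocab_size / hidden_size / initializer_range
    clone = TFSharedEmbeddings.from_config(cfg)         # plain Keras reconstruction path
    assert clone.vocab_size == layer.vocab_size
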
@@ -24,8 +24,19 @@ import numpy as np
 import tensorflow as tf
 
 from .configuration_xlm import XLMConfig
-from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
+from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
+from .modeling_tf_utils import (
+    TFMultipleChoiceLoss,
+    TFPreTrainedModel,
+    TFQuestionAnsweringLoss,
+    TFSequenceClassificationLoss,
+    TFSequenceSummary,
+    TFSharedEmbeddings,
+    TFTokenClassificationLoss,
+    get_initializer,
+    keras_serializable,
+    shape_list,
+)
 from .tokenization_utils import BatchEncoding
 
 
@@ -198,7 +209,10 @@ class TFTransformerFFN(tf.keras.layers.Layer):
         return x
 
 
+@keras_serializable
 class TFXLMMainLayer(tf.keras.layers.Layer):
+    config_class = XLMConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
@@ -717,7 +731,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
     the pooled output) e.g. for GLUE tasks. """,
     XLM_START_DOCSTRING,
 )
-class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
+class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificationLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels
@@ -726,8 +740,27 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
         self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
 
     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
         r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the sequence classification/regression loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
+            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
+
         Returns:
             :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
             logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, config.num_labels)`):
@@ -751,19 +784,261 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
 
         tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
         model = TFXLMForSequenceClassification.from_pretrained('xlm-mlm-en-2048')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        labels = tf.constant([1])[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        logits = outputs[0]
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        labels = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, logits = outputs[:2]
 
         """
-        transformer_outputs = self.transformer(inputs, **kwargs)
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
         output = transformer_outputs[0]
 
         logits = self.sequence_summary(output)
 
         outputs = (logits,) + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
-        return outputs
+
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """XLM Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    XLM_START_DOCSTRING,
+)
+class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.transformer = TFXLMMainLayer(config, name="transformer")
+        self.sequence_summary = TFSequenceSummary(config, initializer_range=config.init_std, name="sequence_summary")
+
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to build the network.
+
+        Returns:
+            tf.Tensor with dummy inputs
+        """
+        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
+
+    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
+    def call(
+        self,
+        inputs,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+            classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
+                `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
+
+                Classification scores (before SoftMax).
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            import tensorflow as tf
+            from transformers import XLMTokenizer, TFXLMForMultipleChoice
+
+            tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+            model = TFXLMForMultipleChoice.from_pretrained('xlm-mlm-en-2048')
+            choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+
+            input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+            labels = tf.reshape(tf.constant(1), (-1, 1))
+            outputs = model(input_ids, labels=labels)
+
+            loss, classification_scores = outputs[:2]
+
+        """
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            langs = inputs[2] if len(inputs) > 2 else langs
+            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
+            position_ids = inputs[4] if len(inputs) > 4 else position_ids
+            lengths = inputs[5] if len(inputs) > 5 else lengths
+            cache = inputs[6] if len(inputs) > 6 else cache
+            head_mask = inputs[7] if len(inputs) > 7 else head_mask
+            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
+            assert len(inputs) <= 9, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            input_ids = inputs.get("input_ids")
+            attention_mask = inputs.get("attention_mask", attention_mask)
+            langs = inputs.get("langs", langs)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
+            position_ids = inputs.get("position_ids", position_ids)
+            lengths = inputs.get("lengths", lengths)
+            cache = inputs.get("cache", cache)
+            head_mask = inputs.get("head_mask", head_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            assert len(inputs) <= 9, "Too many inputs."
+        else:
+            input_ids = inputs
+
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+
+        flat_inputs = [
+            flat_input_ids,
+            flat_attention_mask,
+            langs,
+            flat_token_type_ids,
+            flat_position_ids,
+            lengths,
+            cache,
+            head_mask,
+            inputs_embeds,
+        ]
+
+        transformer_outputs = self.transformer(flat_inputs, training=training)
+        output = transformer_outputs[0]
+        logits = self.sequence_summary(output)
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+        outputs = (reshaped_logits,) + transformer_outputs[1:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
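
As in the RoBERTa version, the multiple-choice head expects ``input_ids`` of shape ``(batch_size, num_choices, seq_length)``, flattens them to ``(batch_size * num_choices, seq_length)`` before the transformer, and reshapes the per-choice scores back. A shape-only illustration of that round trip, independent of any checkpoint::

    import tensorflow as tf

    batch_size, num_choices, seq_length = 2, 4, 7
    input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)

    flat_input_ids = tf.reshape(input_ids, (-1, seq_length))   # (8, 7): one row per choice
    logits = tf.zeros((batch_size * num_choices, 1))           # stand-in for the summary/classifier output
    reshaped_logits = tf.reshape(logits, (-1, num_choices))    # (2, 4): one score per choice
    print(flat_input_ids.shape, reshaped_logits.shape)
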
+@add_start_docstrings(
+    """XLM Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
+    XLM_START_DOCSTRING,
+)
+class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.transformer = TFXLMMainLayer(config, name="transformer")
+        self.dropout = tf.keras.layers.Dropout(config.dropout)
+        self.classifier = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+
+    @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+            scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
+                Classification scores (before SoftMax).
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            import tensorflow as tf
+            from transformers import XLMTokenizer, TFXLMForTokenClassification
+
+            tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+            model = TFXLMForTokenClassification.from_pretrained('xlm-mlm-en-2048')
+            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
+            labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
+            outputs = model(input_ids, labels=labels)
+            loss, scores = outputs[:2]
+
+        """
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
+
+        sequence_output = transformer_outputs[0]
+
+        sequence_output = self.dropout(sequence_output, training=training)
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + transformer_outputs[2:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
 @add_start_docstrings(
@@ -771,7 +1046,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel):
     the hidden-states output to compute `span start logits` and `span end logits`). """,
     XLM_START_DOCSTRING,
 )
-class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
+class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFXLMMainLayer(config, name="transformer")
@@ -780,8 +1055,34 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
         )
 
     @add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
         r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
         Returns:
             :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
             start_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length,)`):
@@ -807,12 +1108,27 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
 
         tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
         model = TFXLMForQuestionAnsweringSimple.from_pretrained('xlm-mlm-en-2048')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        start_scores, end_scores = outputs[:2]
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
+        start_scores, end_scores = model(input_dict)
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
+        answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
 
         """
-        transformer_outputs = self.transformer(inputs, **kwargs)
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            training=training,
+        )
+
         sequence_output = transformer_outputs[0]
 
@@ -825,4 +1141,10 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel):
             1:
         ]  # Keep mems, hidden states, attentions if there are in it
 
-        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
@@ -22,6 +22,8 @@ from .configuration_xlm_roberta import XLMRobertaConfig
 from .file_utils import add_start_docstrings
 from .modeling_tf_roberta import (
     TFRobertaForMaskedLM,
+    TFRobertaForMultipleChoice,
+    TFRobertaForQuestionAnswering,
     TFRobertaForSequenceClassification,
     TFRobertaForTokenClassification,
     TFRobertaModel,
@@ -114,3 +116,30 @@ class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
     """
 
     config_class = XLMRobertaConfig
+
+
+@add_start_docstrings(
+    """XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
+    XLM_ROBERTA_START_DOCSTRING,
+)
+class TFXLMRobertaForQuestionAnswering(TFRobertaForQuestionAnswering):
+    """
+    This class overrides :class:`~transformers.TFRobertaForQuestionAnsweringSimple`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+
+    config_class = XLMRobertaConfig
+
+
+@add_start_docstrings(
+    """Roberta Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    XLM_ROBERTA_START_DOCSTRING,
+)
+class TFXLMRobertaForMultipleChoice(TFRobertaForMultipleChoice):
+    """
+    This class overrides :class:`~transformers.TFRobertaForMultipleChoice`. Please check the
+    superclass for the appropriate documentation alongside usage examples.
+    """
+
+    config_class = XLMRobertaConfig
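
The two XLM-RoBERTa classes above add no logic of their own; they reuse the RoBERTa implementations and only pin ``config_class`` to ``XLMRobertaConfig``. A minimal sketch of building one from a small config (the values are illustrative, and the module path assumes these classes live in ``modeling_tf_xlm_roberta`` as in this commit)::

    from transformers import XLMRobertaConfig
    from transformers.modeling_tf_xlm_roberta import TFXLMRobertaForMultipleChoice

    config = XLMRobertaConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
    )
    model = TFXLMRobertaForMultipleChoice(config)  # behaves exactly like TFRobertaForMultipleChoice
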
@ -23,11 +23,15 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_xlnet import XLNetConfig
|
from .configuration_xlnet import XLNetConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
|
TFMultipleChoiceLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
|
TFQuestionAnsweringLoss,
|
||||||
|
TFSequenceClassificationLoss,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
|
TFTokenClassificationLoss,
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@ -938,7 +942,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
|
|||||||
the pooled output) e.g. for GLUE tasks. """,
|
the pooled output) e.g. for GLUE tasks. """,
|
||||||
XLNET_START_DOCSTRING,
|
XLNET_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
|
class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
@ -952,8 +956,28 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
def call(self, inputs, **kwargs):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
mems=None,
|
||||||
|
perm_mask=None,
|
||||||
|
target_mapping=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
input_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=True,
|
||||||
|
labels=None,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||||
|
Labels for computing the sequence classification/regression loss.
|
||||||
|
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
||||||
|
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
||||||
|
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
|
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
|
||||||
logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:(batch_size, config.num_labels)`):
|
logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:(batch_size, config.num_labels)`):
|
||||||
@@ -981,12 +1005,24 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):

         tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
         model = TFXLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        logits = outputs[0]
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        labels = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, logits = outputs[:2]

         """
-        transformer_outputs = self.transformer(inputs, **kwargs)
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            mems=mems,
+            perm_mask=perm_mask,
+            target_mapping=target_mapping,
+            token_type_ids=token_type_ids,
+            input_mask=input_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+        )
         output = transformer_outputs[0]

         output = self.sequence_summary(output)
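Note: because the head now accepts labels and returns the loss as its first output, a custom training step can consume that loss directly. A minimal sketch (same model and tokenizer as the docstring above; averaging the per-example loss before backprop is an assumption of this sketch, not something the diff prescribes):

    import tensorflow as tf
    from transformers import TFXLNetForSequenceClassification, XLNetTokenizer

    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    model = TFXLNetForSequenceClassification.from_pretrained("xlnet-base-cased")
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # batch size 1
    labels = tf.reshape(tf.constant(1), (-1, 1))

    with tf.GradientTape() as tape:
        loss, logits = model(input_ids, labels=labels, training=True)[:2]
        loss = tf.reduce_mean(loss)  # reduce the per-example loss to a scalar

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))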
@@ -994,7 +1030,159 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):

         outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

-        return outputs  # return logits, (mems), (hidden states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """XLNET Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    XLNET_START_DOCSTRING,
+)
+class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.transformer = TFXLNetMainLayer(config, name="transformer")
+        self.sequence_summary = TFSequenceSummary(
+            config, initializer_range=config.initializer_range, name="sequence_summary"
+        )
+        self.logits_proj = tf.keras.layers.Dense(
+            1, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
+        )
+
+    @property
+    def dummy_inputs(self):
+        """ Dummy inputs to build the network.
+
+        Returns:
+            tf.Tensor with dummy inputs
+        """
+        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
+
+    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+    def call(
+        self,
+        inputs,
+        token_type_ids=None,
+        input_mask=None,
+        attention_mask=None,
+        mems=None,
+        perm_mask=None,
+        target_mapping=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=True,
+        labels=None,
+        training=False,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+
+    Return:
+        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
+        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
+            `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
+
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+            tuple of :obj:`tf.Tensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+    Examples::
+
+        import tensorflow as tf
+        from transformers import XLNetTokenizer, TFXLNetForMultipleChoice
+
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
+        model = TFXLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+
+        input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+        labels = tf.reshape(tf.constant(1), (-1, 1))
+        outputs = model(input_ids, labels=labels)
+
+        loss, classification_scores = outputs[:2]
+
+        """
+        if isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
+            mems = inputs[2] if len(inputs) > 2 else mems
+            perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
+            target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
+            token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
+            input_mask = inputs[6] if len(inputs) > 6 else input_mask
+            head_mask = inputs[7] if len(inputs) > 7 else head_mask
+            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
+            use_cache = inputs[9] if len(inputs) > 9 else use_cache
+            assert len(inputs) <= 10, "Too many inputs."
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            input_ids = inputs.get("input_ids")
+            attention_mask = inputs.get("attention_mask", attention_mask)
+            mems = inputs.get("mems", mems)
+            perm_mask = inputs.get("perm_mask", perm_mask)
+            target_mapping = inputs.get("target_mapping", target_mapping)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
+            input_mask = inputs.get("input_mask", input_mask)
+            head_mask = inputs.get("head_mask", head_mask)
+            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
+            use_cache = inputs.get("use_cache", use_cache)
+            assert len(inputs) <= 10, "Too many inputs."
+        else:
+            input_ids = inputs
+
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
+
+        flat_inputs = [
+            flat_input_ids,
+            flat_attention_mask,
+            mems,
+            perm_mask,
+            target_mapping,
+            flat_token_type_ids,
+            flat_input_mask,
+            head_mask,
+            inputs_embeds,
+            use_cache,
+        ]
+
+        transformer_outputs = self.transformer(flat_inputs, training=training)
+        output = transformer_outputs[0]
+        logits = self.sequence_summary(output)
+        logits = self.logits_proj(logits)
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+
+        outputs = (reshaped_logits,) + transformer_outputs[1:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss = self.compute_loss(labels, reshaped_logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (mems), (hidden states), (attentions)
+
+
 @add_start_docstrings(
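Note: the new multiple-choice head expects inputs of shape (batch_size, num_choices, seq_length); choices are flattened into the batch dimension, scored with a single-unit projection, and the scores reshaped back to one logit per choice. A standalone sketch of just that reshaping, with toy shapes and a plain Dense layer standing in for the transformer plus sequence summary:

    import tensorflow as tf

    batch_size, num_choices, seq_length, hidden = 2, 4, 16, 8

    input_ids = tf.zeros((batch_size, num_choices, seq_length), dtype=tf.int32)  # toy token ids
    flat_input_ids = tf.reshape(input_ids, (-1, seq_length))          # (batch * num_choices, seq_length)

    # ... the transformer and sequence summary would run on flat_input_ids here ...
    summaries = tf.random.normal((batch_size * num_choices, hidden))  # placeholder pooled outputs
    logits = tf.keras.layers.Dense(1)(summaries)                      # one score per (example, choice)
    reshaped_logits = tf.reshape(logits, (-1, num_choices))           # back to (batch, num_choices)
    print(reshaped_logits.shape)  # (2, 4)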
@@ -1002,7 +1190,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
     the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
     XLNET_START_DOCSTRING,
 )
-class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
+class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels
@@ -1012,8 +1200,26 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        mems=None,
+        perm_mask=None,
+        target_mapping=None,
+        token_type_ids=None,
+        input_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=True,
+        labels=None,
+        training=False,
+    ):
         r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+            Labels for computing the token classification loss.
+            Indices should be in ``[0, ..., config.num_labels - 1]``.
+
     Return:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
         logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:(batch_size, config.num_labels)`):
@@ -1041,19 +1247,36 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):

         tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
         model = TFXLNetForTokenClassification.from_pretrained('xlnet-large-cased')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        scores = outputs[0]
+        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
+        labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids)))  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, scores = outputs[:2]

         """
-        transformer_outputs = self.transformer(inputs, **kwargs)
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            mems=mems,
+            perm_mask=perm_mask,
+            target_mapping=target_mapping,
+            token_type_ids=token_type_ids,
+            input_mask=input_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            training=training,
+        )
         output = transformer_outputs[0]

         logits = self.classifier(output)

         outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

-        return outputs  # return logits, (mems), (hidden states), (attentions)
+        if labels is not None:
+            loss = self.compute_loss(labels, logits)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)


 @add_start_docstrings(
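Note: for token classification the labels tensor carries one entry per token, and positions that should not contribute to the loss were conventionally marked with -1 (this is the same masking the old TF trainer applied inline, see the removed active_loss code further down in this diff). A self-contained sketch of that masked, per-token loss with toy data:

    import tensorflow as tf

    # Toy batch: 1 sequence of 5 tokens, 3 entity classes; -1 marks positions to ignore.
    logits = tf.random.normal((1, 5, 3))
    labels = tf.constant([[0, 2, 1, -1, -1]])

    active = tf.reshape(labels, (-1,)) != -1
    active_logits = tf.boolean_mask(tf.reshape(logits, (-1, 3)), active)
    active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    per_token_loss = loss_fn(active_labels, active_logits)
    print(tf.reduce_mean(per_token_loss))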
@@ -1061,7 +1284,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel):
     the hidden-states output to compute `span start logits` and `span end logits`). """,
     XLNET_START_DOCSTRING,
 )
-class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
+class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFXLNetMainLayer(config, name="transformer")
@@ -1070,8 +1293,35 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
         )

     @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        mems=None,
+        perm_mask=None,
+        target_mapping=None,
+        token_type_ids=None,
+        input_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=True,
+        start_positions=None,
+        end_positions=None,
+        cls_index=None,
+        p_mask=None,
+        is_impossible=None,
+        training=False,
+    ):
         r"""
+        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Position outside of the sequence are not taken into account for computing the loss.
+
     Returns:
         :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
         loss (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
@@ -1103,12 +1353,27 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):

         tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
         model = TFXLNetForQuestionAnsweringSimple.from_pretrained('xlnet-base-cased')
-        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
-        outputs = model(input_ids)
-        start_scores, end_scores = outputs[:2]
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
+        start_scores, end_scores = model(input_dict)
+
+        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
+        answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])

         """
-        transformer_outputs = self.transformer(inputs, **kwargs)
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            mems=mems,
+            perm_mask=perm_mask,
+            target_mapping=target_mapping,
+            token_type_ids=token_type_ids,
+            input_mask=input_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            training=training,
+        )

         sequence_output = transformer_outputs[0]

@@ -1121,7 +1386,13 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
             1:
         ]  # Keep mems, hidden states, attentions if there are in it

-        return outputs  # start_logits, end_logits, (mems), (hidden_states), (attentions)
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.compute_loss(labels, outputs[:2])
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)


 # @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
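Note: when start_positions/end_positions are provided, the question-answering loss is the average of a start-index loss and an end-index loss, mirroring the (start_loss + end_loss) / 2.0 computation the TF trainer previously did inline. A toy sketch of that averaging:

    import tensorflow as tf

    # Toy span-extraction batch: 1 example, 8 tokens.
    start_logits = tf.random.normal((1, 8))
    end_logits = tf.random.normal((1, 8))
    start_positions = tf.constant([2])
    end_positions = tf.constant([5])

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    start_loss = loss_fn(start_positions, start_logits)
    end_loss = loss_fn(end_positions, end_logits)
    loss = (start_loss + end_loss) / 2.0  # same averaging the QA loss uses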
@@ -58,27 +58,41 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
         }


-def create_optimizer(init_lr, num_train_steps, num_warmup_steps, end_lr=0.0, optimizer_type="adamw"):
+def create_optimizer(
+    init_lr,
+    num_train_steps,
+    num_warmup_steps,
+    min_lr_ratio=0.0,
+    adam_epsilon=1e-8,
+    weight_decay_rate=0.0,
+    include_in_weight_decay=None,
+):
     """Creates an optimizer with learning rate schedule."""
     # Implements linear decay of the learning rate.
     lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
-        initial_learning_rate=init_lr, decay_steps=num_train_steps, end_learning_rate=end_lr,
+        initial_learning_rate=init_lr,
+        decay_steps=num_train_steps - num_warmup_steps,
+        end_learning_rate=init_lr * min_lr_ratio,
     )
     if num_warmup_steps:
         lr_schedule = WarmUp(
             initial_learning_rate=init_lr, decay_schedule_fn=lr_schedule, warmup_steps=num_warmup_steps,
         )
-    optimizer = AdamWeightDecay(
-        learning_rate=lr_schedule,
-        weight_decay_rate=0.01,
-        beta_1=0.9,
-        beta_2=0.999,
-        epsilon=1e-6,
-        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
-    )
-    return optimizer
+    if weight_decay_rate > 0.0:
+        optimizer = AdamWeightDecay(
+            learning_rate=lr_schedule,
+            weight_decay_rate=weight_decay_rate,
+            beta_1=0.9,
+            beta_2=0.999,
+            epsilon=adam_epsilon,
+            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+            include_in_weight_decay=include_in_weight_decay,
+        )
+    else:
+        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=adam_epsilon)
+
+    # We return the optimizer and the LR scheduler in order to better track the
+    # evolution of the LR independently of the optimizer.
+    return optimizer, lr_schedule


 class AdamWeightDecay(tf.keras.optimizers.Adam):
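Note: create_optimizer now returns an (optimizer, lr_schedule) pair so the learning-rate schedule can be tracked independently of the optimizer, e.g. for TensorBoard logging. A small usage sketch; the step counts and rates below are arbitrary:

    from transformers.optimization_tf import create_optimizer

    optimizer, lr_schedule = create_optimizer(
        init_lr=3e-5,
        num_train_steps=1000,
        num_warmup_steps=100,
        weight_decay_rate=0.01,  # > 0.0 selects AdamWeightDecay instead of plain Adam
    )

    # The schedule can be queried on its own, independently of the optimizer:
    for step in (1, 100, 500, 1000):
        print(step, float(lr_schedule(step)))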
@@ -3,12 +3,12 @@
 import logging
 import math
 import os
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Tuple

 import numpy as np
 import tensorflow as tf

-from .modeling_tf_utils import TFPreTrainedModel, shape_list
+from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
 from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput
 from .training_args_tf import TFTrainingArguments
@@ -20,13 +20,14 @@ logger = logging.getLogger(__name__)
 class TFTrainer:
     model: TFPreTrainedModel
     args: TFTrainingArguments
-    # something similar to a PT Dataset.
-    # This is just temporary before to have
-    # a framework-agnostic approach for datasets.
     train_dataset: Optional[tf.data.Dataset]
     eval_dataset: Optional[tf.data.Dataset]
     compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
     prediction_loss_only: bool
+    tb_writer: Optional[tf.summary.SummaryWriter] = None
+    optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None
+    global_step: Optional[int] = None
+    epoch: Optional[float] = None

     def __init__(
         self,
@@ -36,6 +37,8 @@ class TFTrainer:
         eval_dataset: Optional[tf.data.Dataset] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         prediction_loss_only=False,
+        tb_writer: Optional[tf.summary.SummaryWriter] = None,
+        optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None,
     ):
         self.model = model
         self.args = args
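Note: the TF trainer now mirrors the PyTorch one in that the TensorBoard writer and the optimizer/scheduler pair are plain constructor arguments instead of names resolved from the training arguments. A construction sketch under the assumption that the datasets yield (features, labels) pairs; the tiny tensors below are placeholders, not a meaningful task:

    import tensorflow as tf
    from transformers import TFTrainer, TFTrainingArguments, TFXLNetForSequenceClassification
    from transformers.optimization_tf import create_optimizer

    training_args = TFTrainingArguments(output_dir="./out", logging_steps=10)

    # Placeholder dataset: 8 four-token examples with binary labels.
    features = {"input_ids": tf.constant([[0, 5, 7, 1]] * 8, dtype=tf.int32)}
    labels = tf.constant([0, 1] * 4, dtype=tf.int64)
    train_ds = tf.data.Dataset.from_tensor_slices((features, labels))

    model = TFXLNetForSequenceClassification.from_pretrained("xlnet-base-cased")
    optimizer, lr_schedule = create_optimizer(3e-5, num_train_steps=100, num_warmup_steps=10)

    trainer = TFTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=train_ds,                 # reuse the toy data for evaluation
        tb_writer=tf.summary.create_file_writer(training_args.logging_dir),  # optional
        optimizers=(optimizer, lr_schedule),   # optional, otherwise built from create_optimizer
    )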
@ -43,120 +46,73 @@ class TFTrainer:
|
|||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.prediction_loss_only = prediction_loss_only
|
self.prediction_loss_only = prediction_loss_only
|
||||||
|
self.optimizers = optimizers
|
||||||
self.gradient_accumulator = GradientAccumulator()
|
self.gradient_accumulator = GradientAccumulator()
|
||||||
|
|
||||||
self._setup_training()
|
if tb_writer is not None:
|
||||||
|
self.tb_writer = tb_writer
|
||||||
def _setup_training(self) -> None:
|
|
||||||
"""
|
|
||||||
Setup the different steps to train a model:
|
|
||||||
- check if all the data are given
|
|
||||||
- create the proper strategy
|
|
||||||
- create the features
|
|
||||||
- prepare the model settings
|
|
||||||
"""
|
|
||||||
self._prepare_dataset()
|
|
||||||
|
|
||||||
with self.args.strategy.scope():
|
|
||||||
self._create_optimizer()
|
|
||||||
_ = self.optimizer.iterations
|
|
||||||
self._set_loss_and_metric()
|
|
||||||
self._create_checkpoint_manager()
|
|
||||||
self._create_summary_writer()
|
|
||||||
|
|
||||||
def _set_loss_and_metric(self) -> None:
|
|
||||||
"""
|
|
||||||
Create the training loss and metric with their name. Allowed names are those listed
|
|
||||||
in the Tensorflow documentation and those contained in the transformers library.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.loss = tf.keras.losses.get(
|
|
||||||
{
|
|
||||||
"class_name": self.args.loss_name,
|
|
||||||
"config": {"from_logits": True, "reduction": tf.keras.losses.Reduction.NONE},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
self.loss = tf.keras.losses.get(
|
|
||||||
{"class_name": self.args.loss_name, "config": {"reduction": tf.keras.losses.Reduction.NONE}}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_summary_writer(self) -> None:
|
|
||||||
"""
|
|
||||||
Create a summary writer to be able to read the logs in Tensorboard.
|
|
||||||
"""
|
|
||||||
self.writer = tf.summary.create_file_writer(self.args.logging_dir)
|
|
||||||
|
|
||||||
def _prepare_dataset(self) -> None:
|
|
||||||
"""
|
|
||||||
Prepare the training, validation and test data.
|
|
||||||
"""
|
|
||||||
if self.train_dataset is not None:
|
|
||||||
self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
|
|
||||||
|
|
||||||
if self.args.max_steps > 0:
|
|
||||||
self.train_steps = self.args.max_steps
|
|
||||||
else:
|
|
||||||
self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
|
|
||||||
|
|
||||||
self.train_dataset = (
|
|
||||||
self.train_dataset.cache()
|
|
||||||
.shuffle(self.num_train_examples)
|
|
||||||
.batch(self.args.train_batch_size)
|
|
||||||
.prefetch(tf.data.experimental.AUTOTUNE)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.max_steps > 0:
|
|
||||||
self.train_dataset = self.train_dataset.repeat(-1)
|
|
||||||
|
|
||||||
self.train_dataset = self.args.strategy.experimental_distribute_dataset(self.train_dataset)
|
|
||||||
else:
|
else:
|
||||||
self.train_steps = 0
|
self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
|
||||||
|
|
||||||
if self.eval_dataset is not None:
|
def get_train_tfdataset(self) -> tf.data.Dataset:
|
||||||
self.eval_dataset = (
|
if self.train_dataset is None:
|
||||||
self.eval_dataset.batch(self.args.eval_batch_size).cache().prefetch(tf.data.experimental.AUTOTUNE)
|
raise ValueError("Trainer: training requires a train_dataset.")
|
||||||
)
|
|
||||||
self.eval_dataset = self.args.strategy.experimental_distribute_dataset(self.eval_dataset)
|
|
||||||
|
|
||||||
def _create_optimizer(self) -> None:
|
self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy()
|
||||||
"""
|
|
||||||
Create the training optimizer with its name. Allowed names are those listed
|
if self.args.max_steps > 0:
|
||||||
in the Tensorflow documentation and those contained in the transformers library.
|
self.train_steps = self.args.max_steps
|
||||||
"""
|
|
||||||
if self.args.optimizer_name == "adamw":
|
|
||||||
self.optimizer = create_optimizer(
|
|
||||||
self.args.learning_rate, self.train_steps, self.args.warmup_steps, self.args.end_lr
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size)
|
||||||
self.optimizer = tf.keras.optimizers.get(
|
|
||||||
{
|
|
||||||
"class_name": self.args.optimizer_name,
|
|
||||||
"config": {"learning_rate": self.args.learning_rate, "epsilon": self.args.adam_epsilon},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
# This is for the case where the optimizer is not Adam-like such as SGD
|
|
||||||
self.optimizer = tf.keras.optimizers.get(
|
|
||||||
{"class_name": self.args.optimizer_name, "config": {"learning_rate": self.args.learning_rate}}
|
|
||||||
)
|
|
||||||
logger.info("Created an/a {} optimizer".format(self.args.optimizer_name))
|
|
||||||
|
|
||||||
def _create_checkpoint_manager(self, max_to_keep: int = 5, load_model: bool = True) -> None:
|
ds = (
|
||||||
|
self.train_dataset.cache()
|
||||||
|
.shuffle(self.num_train_examples)
|
||||||
|
.batch(self.args.train_batch_size)
|
||||||
|
.prefetch(tf.data.experimental.AUTOTUNE)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.max_steps > 0:
|
||||||
|
self.train_dataset = self.train_dataset.repeat(-1)
|
||||||
|
|
||||||
|
return self.args.strategy.experimental_distribute_dataset(ds)
|
||||||
|
|
||||||
|
def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
|
||||||
|
if eval_dataset is None and self.eval_dataset is None:
|
||||||
|
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||||
|
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
ds = eval_dataset.cache().batch(self.args.eval_batch_size).prefetch(tf.data.experimental.AUTOTUNE)
|
||||||
|
|
||||||
|
return self.args.strategy.experimental_distribute_dataset(ds)
|
||||||
|
|
||||||
|
def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset:
|
||||||
|
ds = test_dataset.batch(self.args.eval_batch_size)
|
||||||
|
|
||||||
|
return self.args.strategy.experimental_distribute_dataset(ds)
|
||||||
|
|
||||||
|
def get_optimizers(
|
||||||
|
self,
|
||||||
|
) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]:
|
||||||
"""
|
"""
|
||||||
Create a checkpoint manager in order to be able to make the training
|
Setup the optimizer and the learning rate scheduler.
|
||||||
fault-tolerant.
|
|
||||||
Args:
|
We provide a reasonable default that works well.
|
||||||
max_to_keep: the maximum number of checkpoints to keep in the checkpoint path.
|
If you want to use something else, you can pass a tuple in the Trainer's init,
|
||||||
load_model: if we want to start the training from the latest checkpoint.
|
or override this method in a subclass.
|
||||||
"""
|
"""
|
||||||
ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
|
if self.optimizers is not None:
|
||||||
|
return self.optimizers
|
||||||
|
|
||||||
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=max_to_keep)
|
optimizer, scheduler = create_optimizer(
|
||||||
|
self.args.learning_rate,
|
||||||
|
self.train_steps,
|
||||||
|
self.args.warmup_steps,
|
||||||
|
adam_epsilon=self.args.adam_epsilon,
|
||||||
|
weight_decay_rate=self.args.weight_decay,
|
||||||
|
)
|
||||||
|
|
||||||
if load_model:
|
return optimizer, scheduler
|
||||||
ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
|
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def _evaluate_steps(self, per_replica_features, per_replica_labels):
|
def _evaluate_steps(self, per_replica_features, per_replica_labels):
|
||||||
@@ -182,6 +138,14 @@ class TFTrainer:
     def _prediction_loop(
         self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None
     ) -> PredictionOutput:
+        """
+        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
+
+        Works both with or without labels.
+        """
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only
+
         logger.info("***** Running %s *****", description)
         logger.info("  Batch size = %d", self.args.eval_batch_size)

@@ -196,6 +160,12 @@ class TFTrainer:
             loss = tf.reduce_mean(loss)

             if not prediction_loss_only:
+                if isinstance(logits, tuple):
+                    logits = logits[0]
+
+                if isinstance(labels, tuple):
+                    labels = labels[0]
+
                 if self.args.n_gpu > 1:
                     for val in logits.values:
                         if preds is None:
@@ -240,10 +210,9 @@ class TFTrainer:
         """
         Prediction/evaluation loop, shared by `evaluate()` and `predict()`.
         """
-        if eval_dataset is None:
-            eval_dataset = self.eval_dataset
+        eval_ds = self.get_eval_tfdataset(eval_dataset)

-        output = self._prediction_loop(eval_dataset, description="Evaluation")
+        output = self._prediction_loop(eval_ds, description="Evaluation")

         return output.metrics

@ -251,12 +220,25 @@ class TFTrainer:
|
|||||||
"""
|
"""
|
||||||
Train method to train the model.
|
Train method to train the model.
|
||||||
"""
|
"""
|
||||||
|
train_ds = self.get_train_tfdataset()
|
||||||
|
|
||||||
if self.args.debug:
|
if self.args.debug:
|
||||||
tf.summary.trace_on(graph=True, profiler=True)
|
tf.summary.trace_on(graph=True, profiler=True)
|
||||||
|
|
||||||
self.gradient_accumulator.reset()
|
self.gradient_accumulator.reset()
|
||||||
|
|
||||||
iterations = self.optimizer.iterations
|
with self.args.strategy.scope():
|
||||||
|
optimizer, lr_scheduler = self.get_optimizers()
|
||||||
|
iterations = optimizer.iterations
|
||||||
|
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model)
|
||||||
|
self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, PREFIX_CHECKPOINT_DIR, max_to_keep=5)
|
||||||
|
|
||||||
|
if self.model.ckpt_manager.latest_checkpoint:
|
||||||
|
logger.info(
|
||||||
|
"Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint
|
||||||
|
)
|
||||||
|
|
||||||
|
ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial()
|
||||||
|
|
||||||
if iterations.numpy() > 0:
|
if iterations.numpy() > 0:
|
||||||
logger.info("Start the training from the last checkpoint")
|
logger.info("Start the training from the last checkpoint")
|
||||||
@ -268,21 +250,30 @@ class TFTrainer:
|
|||||||
|
|
||||||
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
|
epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs
|
||||||
|
|
||||||
|
if self.args.fp16:
|
||||||
|
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
|
||||||
|
tf.keras.mixed_precision.experimental.set_policy(policy)
|
||||||
|
|
||||||
|
with self.tb_writer.as_default():
|
||||||
|
tf.summary.text("args", self.args.to_json_string())
|
||||||
|
|
||||||
|
self.tb_writer.flush()
|
||||||
|
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(" Num examples = %d", self.num_train_examples)
|
logger.info(" Num examples = %d", self.num_train_examples)
|
||||||
logger.info(" Num Epochs = %d", epochs)
|
logger.info(" Num Epochs = %d", epochs)
|
||||||
logger.info(" Total optimization steps = %d", self.train_steps)
|
logger.info(" Total optimization steps = %d", self.train_steps)
|
||||||
|
|
||||||
for epoch in range(start_epoch, int(epochs + 1)):
|
for epoch in range(start_epoch, int(epochs + 1)):
|
||||||
for training_loss in self._training_steps():
|
for training_loss in self._training_steps(train_ds, optimizer):
|
||||||
step = iterations.numpy()
|
step = iterations.numpy()
|
||||||
|
|
||||||
if self.args.debug:
|
if self.args.debug:
|
||||||
with self.writer.as_default():
|
with self.tb_writer.as_default():
|
||||||
tf.summary.scalar("loss", training_loss, step=step)
|
tf.summary.scalar("loss", training_loss, step=step)
|
||||||
|
|
||||||
if step == 1 and self.args.debug:
|
if step == 1 and self.args.debug:
|
||||||
with self.writer.as_default():
|
with self.tb_writer.as_default():
|
||||||
tf.summary.trace_export(name="training", step=step, profiler_outdir=self.args.logging_dir)
|
tf.summary.trace_export(name="training", step=step, profiler_outdir=self.args.logging_dir)
|
||||||
|
|
||||||
if self.args.evaluate_during_training and step % self.args.eval_steps == 0:
|
if self.args.evaluate_during_training and step % self.args.eval_steps == 0:
|
||||||
@ -293,17 +284,16 @@ class TFTrainer:
|
|||||||
eval_key = "eval_{}".format(key)
|
eval_key = "eval_{}".format(key)
|
||||||
logs[eval_key] = value
|
logs[eval_key] = value
|
||||||
|
|
||||||
if callable(self.optimizer.learning_rate):
|
logs["learning_rate"] = lr_scheduler(step).numpy()
|
||||||
logs["learning_rate"] = self.optimizer.learning_rate(step).numpy()
|
|
||||||
else:
|
|
||||||
logs["learning_rate"] = self.optimizer.learning_rate.numpy()
|
|
||||||
|
|
||||||
logger.info("Epoch {} Step {} Validation Metrics {}".format(epoch, step, logs))
|
logger.info("Epoch {} Step {} Validation Metrics {}".format(epoch, step, logs))
|
||||||
|
|
||||||
with self.writer.as_default():
|
with self.tb_writer.as_default():
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
tf.summary.scalar(k, v, step=step)
|
tf.summary.scalar(k, v, step=step)
|
||||||
|
|
||||||
|
self.tb_writer.flush()
|
||||||
|
|
||||||
if step % self.args.logging_steps == 0:
|
if step % self.args.logging_steps == 0:
|
||||||
logger.info("Epoch {} Step {} Train Loss {:.4f}".format(epoch, step, training_loss.numpy()))
|
logger.info("Epoch {} Step {} Train Loss {:.4f}".format(epoch, step, training_loss.numpy()))
|
||||||
|
|
||||||
@ -314,21 +304,21 @@ class TFTrainer:
|
|||||||
if step % self.train_steps == 0:
|
if step % self.train_steps == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
def _training_steps(self):
|
def _training_steps(self, ds, optimizer):
|
||||||
"""
|
"""
|
||||||
Returns a generator over training steps (i.e. parameters update).
|
Returns a generator over training steps (i.e. parameters update).
|
||||||
"""
|
"""
|
||||||
for i, loss in enumerate(self._accumulate_next_gradients()):
|
for i, loss in enumerate(self._accumulate_next_gradients(ds)):
|
||||||
if i % self.args.gradient_accumulation_steps == 0:
|
if i % self.args.gradient_accumulation_steps == 0:
|
||||||
self._apply_gradients()
|
self._apply_gradients(optimizer)
|
||||||
yield loss
|
yield loss
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def _apply_gradients(self):
|
def _apply_gradients(self, optimizer):
|
||||||
"""Applies the gradients (cross-replica)."""
|
"""Applies the gradients (cross-replica)."""
|
||||||
self.args.strategy.experimental_run_v2(self._step)
|
self.args.strategy.experimental_run_v2(self._step, args=(optimizer,))
|
||||||
|
|
||||||
def _step(self):
|
def _step(self, optimizer):
|
||||||
"""Applies gradients and resets accumulation."""
|
"""Applies gradients and resets accumulation."""
|
||||||
gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
|
gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync
|
||||||
gradients = [
|
gradients = [
|
||||||
@ -336,12 +326,12 @@ class TFTrainer:
|
|||||||
]
|
]
|
||||||
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
|
gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients]
|
||||||
|
|
||||||
self.optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables)))
|
||||||
self.gradient_accumulator.reset()
|
self.gradient_accumulator.reset()
|
||||||
|
|
||||||
def _accumulate_next_gradients(self):
|
def _accumulate_next_gradients(self, ds):
|
||||||
"""Accumulates the gradients from the next element in dataset."""
|
"""Accumulates the gradients from the next element in dataset."""
|
||||||
iterator = iter(self.train_dataset)
|
iterator = iter(ds)
|
||||||
|
|
||||||
@tf.function
|
@tf.function
|
||||||
def _accumulate_next():
|
def _accumulate_next():
|
||||||
@ -388,23 +378,10 @@ class TFTrainer:
|
|||||||
labels: the batched labels.
|
labels: the batched labels.
|
||||||
training: run the model in training mode or not
|
training: run the model in training mode or not
|
||||||
"""
|
"""
|
||||||
if self.args.mode == "text-classification" or self.args.mode == "token-classification":
|
if isinstance(labels, (dict)):
|
||||||
logits = self.model(features, training=training)[0]
|
loss, logits = self.model(features, training=training, **labels)[:2]
|
||||||
else:
|
else:
|
||||||
logits = self.model(features, training=training)
|
loss, logits = self.model(features, labels=labels, training=training)[:2]
|
||||||
|
|
||||||
if self.args.mode == "token-classification":
|
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -1
|
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
|
||||||
loss = self.loss(labels, reduced_logits)
|
|
||||||
elif self.args.mode == "question-answering":
|
|
||||||
start_loss = self.loss(labels["start_position"], logits[0])
|
|
||||||
end_loss = self.loss(labels["end_position"], logits[1])
|
|
||||||
loss = (start_loss + end_loss) / 2.0
|
|
||||||
else:
|
|
||||||
loss = self.loss(labels, logits)
|
|
||||||
|
|
||||||
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
|
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
|
||||||
|
|
||||||
return loss, logits
|
return loss, logits
|
||||||
@@ -418,19 +395,24 @@ class TFTrainer:
             test_dataset: something similar to a PT Dataset. This is just
                 temporary before to have a framework-agnostic approach for datasets.
         """
-        test_dataset = test_dataset.batch(self.args.eval_batch_size)
-        test_dataset = self.args.strategy.experimental_distribute_dataset(test_dataset)
+        test_ds = self.get_test_tfdataset(test_dataset)

-        return self._prediction_loop(test_dataset, description="Prediction")
+        return self._prediction_loop(test_ds, description="Prediction")

-    def save_model(self) -> None:
+    def save_model(self, output_dir: Optional[str] = None):
         """
         Save the pretrained model and create a Tensorflow saved model.
         """
-        logger.info("Saving model in {}".format(self.args.output_dir))
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+
+        logger.info("Saving model in {}".format(output_dir))

         path = os.path.join(self.args.output_dir, "saved_model")

         logger.info("Saving model in {}".format(path))
         os.makedirs(path, exist_ok=True)

+        if not isinstance(self.model, TFPreTrainedModel):
+            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
+
         self.model.save_pretrained(self.args.output_dir)
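Note: save_model ultimately relies on the save_pretrained / from_pretrained round trip (plus a TensorFlow SavedModel export under saved_model/). A minimal sketch of that round trip with an arbitrary directory name:

    from transformers import TFXLNetForSequenceClassification

    model = TFXLNetForSequenceClassification.from_pretrained("xlnet-base-cased")
    model.save_pretrained("./tf_checkpoint")          # writes config.json + tf_model.h5
    reloaded = TFXLNetForSequenceClassification.from_pretrained("./tf_checkpoint")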
@@ -1,6 +1,7 @@
 import dataclasses
 import json
 import logging
+import os
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional, Tuple

@@ -27,6 +28,17 @@ def is_tpu_available():
 logger = logging.getLogger(__name__)


+def default_logdir() -> str:
+    """
+    Same default as PyTorch
+    """
+    import socket
+    from datetime import datetime
+
+    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+    return os.path.join("runs", current_time + "_" + socket.gethostname())
+
+
 @dataclass
 class TrainingArguments:
     """
@@ -97,7 +109,7 @@ class TrainingArguments:
     )
     warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})

-    logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
+    logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
     logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
     save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
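Note: default_logdir reproduces the PyTorch trainer's default TensorBoard location, e.g. runs/Jun03_14-21-07_<hostname>. A quick way to see the value it would pick on the current machine:

    import os
    import socket
    from datetime import datetime

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    print(os.path.join("runs", current_time + "_" + socket.gethostname()))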
@@ -14,28 +14,9 @@ if is_tf_available():

 @dataclass
 class TFTrainingArguments(TrainingArguments):
-    optimizer_name: str = field(
-        default="adam",
-        metadata={
-            "help": 'Name of a Tensorflow optimizer among "adadelta, adagrad, adam, adamax, ftrl, nadam, rmsprop, sgd, adamw"'
-        },
-    )
-    mode: str = field(
-        default="text-classification",
-        metadata={"help": 'Type of task, one of "text-classification", "token-classification", "question-answering"'},
-    )
-    loss_name: str = field(
-        default="SparseCategoricalCrossentropy",
-        metadata={
-            "help": "Name of a Tensorflow loss. For the list see: https://www.tensorflow.org/api_docs/python/tf/keras/losses"
-        },
-    )
     tpu_name: str = field(
         default=None, metadata={"help": "Name of TPU"},
     )
-    end_lr: float = field(
-        default=0, metadata={"help": "End learning rate for optimizer"},
-    )
     eval_steps: int = field(default=1000, metadata={"help": "Run an evaluation every X steps."})
     debug: bool = field(
         default=False, metadata={"help": "Activate the trace to record computation graphs and profiling information"}
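Note: with optimizer_name, mode, loss_name and end_lr removed, the TF arguments only carry scheduling and bookkeeping knobs; the optimizer comes from create_optimizer and the loss from the task model itself. A sketch of a typical construction (all values are arbitrary):

    from transformers import TFTrainingArguments

    args = TFTrainingArguments(
        output_dir="./out",
        learning_rate=3e-5,
        warmup_steps=100,
        weight_decay=0.01,   # feeds create_optimizer's weight_decay_rate
        adam_epsilon=1e-8,
        eval_steps=500,
    )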
@ -30,7 +30,7 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding
|
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding, TFSharedEmbeddings
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
gpus = tf.config.list_physical_devices("GPU")
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
@ -107,26 +107,45 @@ class TFModelTesterMixin:
|
|||||||
and getattr(module_member, "_keras_serializable", False)
|
and getattr(module_member, "_keras_serializable", False)
|
||||||
)
|
)
|
||||||
for main_layer_class in tf_main_layer_classes:
|
for main_layer_class in tf_main_layer_classes:
|
||||||
main_layer = main_layer_class(config)
|
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
|
||||||
|
if "T5" in main_layer_class.__name__:
|
||||||
|
# Take the same values than in TFT5ModelTester for this shared layer
|
||||||
|
shared = TFSharedEmbeddings(99, 32, name="shared")
|
||||||
|
main_layer = main_layer_class(config, embed_tokens=shared)
|
||||||
|
else:
|
||||||
|
main_layer = main_layer_class(config)
|
||||||
symbolic_inputs = {
|
symbolic_inputs = {
|
||||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
|
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
|
||||||
outputs = model(inputs_dict)
|
outputs = model(inputs_dict)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
filepath = os.path.join(tmpdirname, "keras_model.h5")
|
filepath = os.path.join(tmpdirname, "keras_model.h5")
|
||||||
model.save(filepath)
|
model.save(filepath)
|
||||||
model = tf.keras.models.load_model(
|
if "T5" in main_layer_class.__name__:
|
||||||
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
|
model = tf.keras.models.load_model(
|
||||||
)
|
filepath,
|
||||||
|
custom_objects={
|
||||||
|
main_layer_class.__name__: main_layer_class,
|
||||||
|
"TFSharedEmbeddings": TFSharedEmbeddings,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = tf.keras.models.load_model(
|
||||||
|
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
|
||||||
|
)
|
||||||
assert isinstance(model, tf.keras.Model)
|
assert isinstance(model, tf.keras.Model)
|
||||||
after_outputs = model(inputs_dict)
|
after_outputs = model(inputs_dict)
|
||||||
self.assert_outputs_same(after_outputs, outputs)
|
self.assert_outputs_same(after_outputs, outputs)
|
||||||
|
|
||||||
def assert_outputs_same(self, after_outputs, outputs):
|
def assert_outputs_same(self, after_outputs, outputs):
|
||||||
# Make sure we don't have nans
|
# Make sure we don't have nans
|
||||||
out_1 = after_outputs[0].numpy()
|
if isinstance(after_outputs, tf.Tensor):
|
||||||
|
out_1 = after_outputs.numpy()
|
||||||
|
else:
|
||||||
|
out_1 = after_outputs[0].numpy()
|
||||||
out_2 = outputs[0].numpy()
|
out_2 = outputs[0].numpy()
|
||||||
self.assertEqual(out_1.shape, out_2.shape)
|
self.assertEqual(out_1.shape, out_2.shape)
|
||||||
out_1 = out_1[~np.isnan(out_1)]
|
out_1 = out_1[~np.isnan(out_1)]
|
||||||
@ -269,7 +288,6 @@ class TFModelTesterMixin:
|
|||||||
inputs_keywords = copy.deepcopy(inputs_dict)
|
inputs_keywords = copy.deepcopy(inputs_dict)
|
||||||
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "inputs", None,)
|
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "inputs", None,)
|
||||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||||
|
|
||||||
output_dict = outputs_dict[0].numpy()
|
output_dict = outputs_dict[0].numpy()
|
||||||
output_keywords = outputs_keywords[0].numpy()
|
output_keywords = outputs_keywords[0].numpy()
|
||||||
|
|
||||||
|
54
tests/test_modeling_tf_flaubert.py
Normal file
54
tests/test_modeling_tf_flaubert.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_tf_available
|
||||||
|
|
||||||
|
from .utils import require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
from transformers import TFFlaubertModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFFlaubertModelIntegrationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_output_embeds_base_model(self):
|
||||||
|
model = TFFlaubertModel.from_pretrained("jplu/tf-flaubert-small-cased")
|
||||||
|
|
||||||
|
input_ids = tf.convert_to_tensor(
|
||||||
|
[[0, 158, 735, 2592, 1424, 6727, 82, 1]], dtype=tf.int32,
|
||||||
|
) # "J'aime flaubert !"
|
||||||
|
|
||||||
|
output = model(input_ids)[0]
|
||||||
|
expected_shape = tf.TensorShape((1, 8, 512))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
# compare the actual values for a slice.
|
||||||
|
expected_slice = tf.convert_to_tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[-1.8768773, -1.566555, 0.27072418],
|
||||||
|
[-1.6920038, -0.5873505, 1.9329599],
|
||||||
|
[-2.9563985, -1.6993835, 1.7972052],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=tf.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
|
55
tests/test_modeling_tf_xlm_roberta.py
Normal file
55
tests/test_modeling_tf_xlm_roberta.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import is_tf_available
|
||||||
|
|
||||||
|
from .utils import require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
from transformers import TFXLMRobertaModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFFlaubertModelIntegrationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_output_embeds_base_model(self):
|
||||||
|
model = TFXLMRobertaModel.from_pretrained("jplu/tf-xlm-roberta-base")
|
||||||
|
|
||||||
|
features = {
|
||||||
|
"input_ids": tf.convert_to_tensor([[0, 2646, 10269, 83, 99942, 2]], dtype=tf.int32), # "My dog is cute"
|
||||||
|
"attention_mask": tf.convert_to_tensor([[1, 1, 1, 1, 1, 1]], dtype=tf.int32),
|
||||||
|
}
|
||||||
|
|
||||||
|
output = model(features)[0]
|
||||||
|
expected_shape = tf.TensorShape((1, 6, 768))
|
||||||
|
self.assertEqual(output.shape, expected_shape)
|
||||||
|
# compare the actual values for a slice.
|
||||||
|
expected_slice = tf.convert_to_tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0681762, 0.10894451, 0.06772504],
|
||||||
|
[-0.06423668, 0.02366615, 0.04329344],
|
||||||
|
[-0.06057295, 0.09974135, -0.00070584],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
dtype=tf.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(np.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
|
@@ -47,7 +47,7 @@ class OptimizationFTest(unittest.TestCase):
         with strategy.scope():
             accumulator = GradientAccumulator()
             variable = tf.Variable([4.0, 3.0])
-            optimizer = create_optimizer(5e-5, 10, 5)
+            optimizer, _ = create_optimizer(5e-5, 10, 5)
             gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

             def accumulate_on_replica(gradient):