Add auto next sentence prediction (#8432)

* Add auto next sentence prediction

* Fix style

* Add mobilebert next sentence prediction
This commit is contained in:
Julien Plu 2020-11-10 17:11:48 +01:00 committed by GitHub
parent c314b1fd3b
commit 8551a99232
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 107 additions and 2 deletions

View File

@ -61,6 +61,7 @@ from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
from .modeling_tf_bert import (
TFBertForMaskedLM,
TFBertForMultipleChoice,
TFBertForNextSentencePrediction,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
@ -120,6 +121,7 @@ from .modeling_tf_mbart import TFMBartForConditionalGeneration
from .modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
TFMobileBertForNextSentencePrediction,
TFMobileBertForPreTraining,
TFMobileBertForQuestionAnswering,
TFMobileBertForSequenceClassification,
@ -355,6 +357,13 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
]
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
[
(BertConfig, TFBertForNextSentencePrediction),
(MobileBertConfig, TFMobileBertForNextSentencePrediction),
]
)
TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
@ -1412,3 +1421,101 @@ class TFAutoModelForMultipleChoice:
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
)
)
class TFAutoModelForNextSentencePrediction:
r"""
This is a generic model class that will be instantiated as one of the model classes of the library---with a
multiple choice classification head---when created with the when created with the
:meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
:meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"TFAutoModelForNextSentencePrediction is designed to be instantiated "
"using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
"`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
)
@classmethod
@replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
def from_config(cls, config):
r"""
Instantiates one of the model classes of the library---with a next sentence prediction head---from a
configuration.
Note:
Loading a model from its configuration file does **not** load the model weights. It only affects the
model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
load the model weights.
Args:
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
List options
Examples::
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download configuration from S3 and cache.
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
>>> model = TFAutoModelForNextSentencePrediction.from_config(config)
"""
if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
raise ValueError(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
)
)
@classmethod
@replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
@add_start_docstrings(
"Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
"pretrained model.",
TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
)
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Examples::
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download model and configuration from S3 and cache.
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> # Update configuration during loading
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
raise ValueError(
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
"Model type should be one of {}.".format(
config.__class__,
cls.__name__,
", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
)
)

View File

@ -87,10 +87,8 @@ IGNORE_NON_AUTO_CONFIGURED = [
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFBertForNextSentencePrediction",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFMobileBertForNextSentencePrediction",
"TFOpenAIGPTDoubleHeadsModel",
"XLMForQuestionAnswering",
"XLMProphetNetDecoder",