mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add auto next sentence prediction (#8432)
* Add auto next sentence prediction * Fix style * Add mobilebert next sentence prediction
This commit is contained in:
parent
c314b1fd3b
commit
8551a99232
@ -61,6 +61,7 @@ from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
|
||||
from .modeling_tf_bert import (
|
||||
TFBertForMaskedLM,
|
||||
TFBertForMultipleChoice,
|
||||
TFBertForNextSentencePrediction,
|
||||
TFBertForPreTraining,
|
||||
TFBertForQuestionAnswering,
|
||||
TFBertForSequenceClassification,
|
||||
@ -120,6 +121,7 @@ from .modeling_tf_mbart import TFMBartForConditionalGeneration
|
||||
from .modeling_tf_mobilebert import (
|
||||
TFMobileBertForMaskedLM,
|
||||
TFMobileBertForMultipleChoice,
|
||||
TFMobileBertForNextSentencePrediction,
|
||||
TFMobileBertForPreTraining,
|
||||
TFMobileBertForQuestionAnswering,
|
||||
TFMobileBertForSequenceClassification,
|
||||
@ -355,6 +357,13 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
# Maps a configuration class to the TF model class that carries a next sentence
# prediction head; consumed by TFAutoModelForNextSentencePrediction below.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
    [
        (BertConfig, TFBertForNextSentencePrediction),
        (MobileBertConfig, TFMobileBertForNextSentencePrediction),
    ]
)
|
||||
|
||||
|
||||
TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
||||
|
||||
@ -1412,3 +1421,101 @@ class TFAutoModelForMultipleChoice:
|
||||
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TFAutoModelForNextSentencePrediction:
    r"""
    This is a generic model class that will be instantiated as one of the model classes of the library---with a
    next sentence prediction head---when created with the
    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        # Direct instantiation is forbidden: the class only dispatches to concrete
        # model classes via its two classmethod constructors.
        raise EnvironmentError(
            "TFAutoModelForNextSentencePrediction is designed to be instantiated "
            "using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
            "`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
        )

    @classmethod
    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
    def from_config(cls, config):
        r"""
        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
        configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights. It only affects the
            model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
            load the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                List options

        Examples::

            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
            >>> # Download configuration from S3 and cache.
            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
            >>> model = TFAutoModelForNextSentencePrediction.from_config(config)
        """
        # Dispatch on the exact config class (not isinstance): subclassed configs
        # must map to their own model classes in the mapping, never the parent's.
        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
            )
        )

    @classmethod
    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
    @add_start_docstrings(
        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
        "pretrained model.",
        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
    )
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Examples::

            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction

            >>> # Download model and configuration from S3 and cache.
            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')

            >>> # Update configuration during loading
            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
            >>> model.config.output_attentions
            True

            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
            >>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
        """
        # No explicit config supplied: resolve one from the checkpoint name/path,
        # keeping any kwargs the config constructor did not consume for the model.
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        # Same exact-type dispatch as from_config (see comment there).
        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **kwargs
            )
        raise ValueError(
            "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
            )
        )
|
||||
|
@ -87,10 +87,8 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"RagSequenceForGeneration",
|
||||
"RagTokenForGeneration",
|
||||
"T5Stack",
|
||||
"TFBertForNextSentencePrediction",
|
||||
"TFFunnelBaseModel",
|
||||
"TFGPT2DoubleHeadsModel",
|
||||
"TFMobileBertForNextSentencePrediction",
|
||||
"TFOpenAIGPTDoubleHeadsModel",
|
||||
"XLMForQuestionAnswering",
|
||||
"XLMProphetNetDecoder",
|
||||
|
Loading…
Reference in New Issue
Block a user