mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 10:41:07 +06:00
Add auto next sentence prediction (#8432)
* Add auto next sentence prediction * Fix style * Add mobilebert next sentence prediction
This commit is contained in:
parent
c314b1fd3b
commit
8551a99232
@ -61,6 +61,7 @@ from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
|
|||||||
from .modeling_tf_bert import (
|
from .modeling_tf_bert import (
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
TFBertForMultipleChoice,
|
TFBertForMultipleChoice,
|
||||||
|
TFBertForNextSentencePrediction,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
@ -120,6 +121,7 @@ from .modeling_tf_mbart import TFMBartForConditionalGeneration
|
|||||||
from .modeling_tf_mobilebert import (
|
from .modeling_tf_mobilebert import (
|
||||||
TFMobileBertForMaskedLM,
|
TFMobileBertForMaskedLM,
|
||||||
TFMobileBertForMultipleChoice,
|
TFMobileBertForMultipleChoice,
|
||||||
|
TFMobileBertForNextSentencePrediction,
|
||||||
TFMobileBertForPreTraining,
|
TFMobileBertForPreTraining,
|
||||||
TFMobileBertForQuestionAnswering,
|
TFMobileBertForQuestionAnswering,
|
||||||
TFMobileBertForSequenceClassification,
|
TFMobileBertForSequenceClassification,
|
||||||
@ -355,6 +357,13 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
(BertConfig, TFBertForNextSentencePrediction),
|
||||||
|
(MobileBertConfig, TFMobileBertForNextSentencePrediction),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
||||||
|
|
||||||
@ -1412,3 +1421,101 @@ class TFAutoModelForMultipleChoice:
|
|||||||
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
|
", ".join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TFAutoModelForNextSentencePrediction:
|
||||||
|
r"""
|
||||||
|
This is a generic model class that will be instantiated as one of the model classes of the library---with a
|
||||||
|
next sentence prediction head---when created with the
|
||||||
|
:meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
|
||||||
|
:meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
|
||||||
|
|
||||||
|
This class cannot be instantiated directly using ``__init__()`` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"TFAutoModelForNextSentencePrediction is designed to be instantiated "
|
||||||
|
"using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
|
||||||
|
def from_config(cls, config):
|
||||||
|
r"""
|
||||||
|
Instantiates one of the model classes of the library---with a next sentence prediction head---from a
|
||||||
|
configuration.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Loading a model from its configuration file does **not** load the model weights. It only affects the
|
||||||
|
model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
|
||||||
|
load the model weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
|
||||||
|
List options
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
|
||||||
|
>>> # Download configuration from S3 and cache.
|
||||||
|
>>> config = AutoConfig.from_pretrained('bert-base-uncased')
|
||||||
|
>>> model = TFAutoModelForNextSentencePrediction.from_config(config)
|
||||||
|
"""
|
||||||
|
if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
|
||||||
|
return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
|
||||||
|
@add_start_docstrings(
|
||||||
|
"Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
|
||||||
|
"pretrained model.",
|
||||||
|
TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
|
||||||
|
)
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r"""
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
|
||||||
|
|
||||||
|
>>> # Download model and configuration from S3 and cache.
|
||||||
|
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
||||||
|
|
||||||
|
>>> # Update configuration during loading
|
||||||
|
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
|
||||||
|
>>> model.config.output_attentions
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
|
||||||
|
>>> config = AutoConfig.from_json_file('./pt_model/bert_pt_model_config.json')
|
||||||
|
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config, kwargs = AutoConfig.from_pretrained(
|
||||||
|
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
|
||||||
|
return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
|
||||||
|
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@ -87,10 +87,8 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"RagSequenceForGeneration",
|
"RagSequenceForGeneration",
|
||||||
"RagTokenForGeneration",
|
"RagTokenForGeneration",
|
||||||
"T5Stack",
|
"T5Stack",
|
||||||
"TFBertForNextSentencePrediction",
|
|
||||||
"TFFunnelBaseModel",
|
"TFFunnelBaseModel",
|
||||||
"TFGPT2DoubleHeadsModel",
|
"TFGPT2DoubleHeadsModel",
|
||||||
"TFMobileBertForNextSentencePrediction",
|
|
||||||
"TFOpenAIGPTDoubleHeadsModel",
|
"TFOpenAIGPTDoubleHeadsModel",
|
||||||
"XLMForQuestionAnswering",
|
"XLMForQuestionAnswering",
|
||||||
"XLMProphetNetDecoder",
|
"XLMProphetNetDecoder",
|
||||||
|
Loading…
Reference in New Issue
Block a user