mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Pickle auto models (#12654)
* PoC, it pickles! * Remove old method. * Apply to every auto object
This commit is contained in:
parent
379f649434
commit
9b3aab2cce
@ -1938,7 +1938,7 @@ class _LazyModule(ModuleType):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self._name, self._import_structure))
|
||||
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
||||
|
||||
|
||||
def copy_func(f):
|
||||
|
@ -14,8 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""Factory function to build auto-model classes."""
|
||||
|
||||
import types
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import copy_func
|
||||
from ...utils import logging
|
||||
@ -401,12 +399,12 @@ def insert_head_doc(docstring, head_doc=""):
|
||||
)
|
||||
|
||||
|
||||
def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased", head_doc=""):
|
||||
def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""):
|
||||
# Create a new class with the right name from the base class
|
||||
new_class = types.new_class(name, (_BaseAutoModelClass,))
|
||||
new_class._model_mapping = model_mapping
|
||||
model_mapping = cls._model_mapping
|
||||
name = cls.__name__
|
||||
class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
|
||||
new_class.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
|
||||
cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
|
||||
|
||||
# Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
|
||||
# have a specific docstrings for them.
|
||||
@ -416,7 +414,7 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
|
||||
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
|
||||
from_config.__doc__ = from_config_docstring
|
||||
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
|
||||
new_class.from_config = classmethod(from_config)
|
||||
cls.from_config = classmethod(from_config)
|
||||
|
||||
if name.startswith("TF"):
|
||||
from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
|
||||
@ -432,8 +430,8 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
|
||||
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
|
||||
from_pretrained.__doc__ = from_pretrained_docstring
|
||||
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
|
||||
new_class.from_pretrained = classmethod(from_pretrained)
|
||||
return new_class
|
||||
cls.from_pretrained = classmethod(from_pretrained)
|
||||
return cls
|
||||
|
||||
|
||||
def get_values(model_mapping):
|
||||
|
@ -308,7 +308,7 @@ from ..xlnet.modeling_xlnet import (
|
||||
XLNetLMHeadModel,
|
||||
XLNetModel,
|
||||
)
|
||||
from .auto_factory import auto_class_factory
|
||||
from .auto_factory import _BaseAutoModelClass, auto_class_update
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
BartConfig,
|
||||
@ -780,64 +780,106 @@ MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||
)
|
||||
|
||||
|
||||
AutoModel = auto_class_factory("AutoModel", MODEL_MAPPING)
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
|
||||
|
||||
AutoModel = auto_class_update(AutoModel)
|
||||
|
||||
|
||||
class AutoModelForPreTraining(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_PRETRAINING_MAPPING
|
||||
|
||||
|
||||
AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
|
||||
|
||||
AutoModelForPreTraining = auto_class_factory(
|
||||
"AutoModelForPreTraining", MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
|
||||
)
|
||||
|
||||
# Private on purpose, the public class will add the deprecation warnings.
|
||||
_AutoModelWithLMHead = auto_class_factory(
|
||||
"AutoModelWithLMHead", MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling"
|
||||
class _AutoModelWithLMHead(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_WITH_LM_HEAD_MAPPING
|
||||
|
||||
|
||||
_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
|
||||
|
||||
|
||||
class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
||||
class AutoModelForMaskedLM(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING
|
||||
|
||||
|
||||
AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
|
||||
|
||||
|
||||
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
AutoModelForSeq2SeqLM = auto_class_update(
|
||||
AutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
|
||||
)
|
||||
|
||||
AutoModelForCausalLM = auto_class_factory(
|
||||
"AutoModelForCausalLM", MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
|
||||
|
||||
class AutoModelForSequenceClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForSequenceClassification = auto_class_update(
|
||||
AutoModelForSequenceClassification, head_doc="sequence classification"
|
||||
)
|
||||
|
||||
AutoModelForMaskedLM = auto_class_factory(
|
||||
"AutoModelForMaskedLM", MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||
)
|
||||
|
||||
AutoModelForSeq2SeqLM = auto_class_factory(
|
||||
"AutoModelForSeq2SeqLM",
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
head_doc="sequence-to-sequence language modeling",
|
||||
checkpoint_for_example="t5-base",
|
||||
)
|
||||
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
AutoModelForSequenceClassification = auto_class_factory(
|
||||
"AutoModelForSequenceClassification", MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, head_doc="sequence classification"
|
||||
)
|
||||
|
||||
AutoModelForQuestionAnswering = auto_class_factory(
|
||||
"AutoModelForQuestionAnswering", MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
|
||||
)
|
||||
AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
|
||||
|
||||
AutoModelForTableQuestionAnswering = auto_class_factory(
|
||||
"AutoModelForTableQuestionAnswering",
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
|
||||
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
|
||||
AutoModelForTableQuestionAnswering = auto_class_update(
|
||||
AutoModelForTableQuestionAnswering,
|
||||
head_doc="table question answering",
|
||||
checkpoint_for_example="google/tapas-base-finetuned-wtq",
|
||||
)
|
||||
|
||||
AutoModelForTokenClassification = auto_class_factory(
|
||||
"AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
|
||||
|
||||
class AutoModelForTokenClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
|
||||
|
||||
|
||||
class AutoModelForMultipleChoice(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
|
||||
|
||||
|
||||
AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
|
||||
|
||||
|
||||
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
|
||||
|
||||
|
||||
AutoModelForNextSentencePrediction = auto_class_update(
|
||||
AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
|
||||
)
|
||||
|
||||
AutoModelForMultipleChoice = auto_class_factory(
|
||||
"AutoModelForMultipleChoice", MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
|
||||
)
|
||||
|
||||
AutoModelForNextSentencePrediction = auto_class_factory(
|
||||
"AutoModelForNextSentencePrediction",
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
head_doc="next sentence prediction",
|
||||
)
|
||||
class AutoModelForImageClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
AutoModelForImageClassification = auto_class_factory(
|
||||
"AutoModelForImageClassification", MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, head_doc="image classification"
|
||||
)
|
||||
|
||||
AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
||||
|
||||
|
||||
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||
|
@ -73,7 +73,7 @@ from ..roberta.modeling_flax_roberta import (
|
||||
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
|
||||
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
||||
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
|
||||
from .auto_factory import auto_class_factory
|
||||
from .auto_factory import _BaseAutoModelClass, auto_class_update
|
||||
from .configuration_auto import (
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
@ -217,59 +217,89 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
|
||||
|
||||
FlaxAutoModelForImageClassification = auto_class_factory(
|
||||
"FlaxAutoModelForImageClassification",
|
||||
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
head_doc="image classification modeling",
|
||||
)
|
||||
class FlaxAutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_MAPPING
|
||||
|
||||
FlaxAutoModelForCausalLM = auto_class_factory(
|
||||
"FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
|
||||
)
|
||||
|
||||
FlaxAutoModelForPreTraining = auto_class_factory(
|
||||
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
|
||||
)
|
||||
FlaxAutoModel = auto_class_update(FlaxAutoModel)
|
||||
|
||||
FlaxAutoModelForMaskedLM = auto_class_factory(
|
||||
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||
|
||||
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
|
||||
|
||||
|
||||
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
||||
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
|
||||
|
||||
|
||||
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForSeq2SeqLM = auto_class_update(
|
||||
FlaxAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
|
||||
)
|
||||
|
||||
|
||||
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
head_doc="sequence-to-sequence language modeling",
|
||||
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForSequenceClassification = auto_class_update(
|
||||
FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
|
||||
)
|
||||
|
||||
FlaxAutoModelForSequenceClassification = auto_class_factory(
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
head_doc="sequence classification",
|
||||
|
||||
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
|
||||
|
||||
|
||||
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForTokenClassification = auto_class_update(
|
||||
FlaxAutoModelForTokenClassification, head_doc="token classification"
|
||||
)
|
||||
|
||||
FlaxAutoModelForQuestionAnswering = auto_class_factory(
|
||||
"FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
|
||||
|
||||
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
|
||||
|
||||
|
||||
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
|
||||
|
||||
|
||||
FlaxAutoModelForNextSentencePrediction = auto_class_update(
|
||||
FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
|
||||
)
|
||||
|
||||
FlaxAutoModelForTokenClassification = auto_class_factory(
|
||||
"FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
|
||||
)
|
||||
|
||||
FlaxAutoModelForMultipleChoice = auto_class_factory(
|
||||
"AutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
|
||||
)
|
||||
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
|
||||
_model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
FlaxAutoModelForNextSentencePrediction = auto_class_factory(
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
head_doc="next sentence prediction",
|
||||
)
|
||||
|
||||
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
head_doc="sequence-to-sequence language modeling",
|
||||
FlaxAutoModelForImageClassification = auto_class_update(
|
||||
FlaxAutoModelForImageClassification, head_doc="image classification"
|
||||
)
|
||||
|
@ -189,7 +189,7 @@ from ..xlnet.modeling_tf_xlnet import (
|
||||
TFXLNetLMHeadModel,
|
||||
TFXLNetModel,
|
||||
)
|
||||
from .auto_factory import auto_class_factory
|
||||
from .auto_factory import _BaseAutoModelClass, auto_class_update
|
||||
from .configuration_auto import (
|
||||
AlbertConfig,
|
||||
BartConfig,
|
||||
@ -487,54 +487,89 @@ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
|
||||
)
|
||||
|
||||
|
||||
TFAutoModel = auto_class_factory("TFAutoModel", TF_MODEL_MAPPING)
|
||||
class TFAutoModel(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_MAPPING
|
||||
|
||||
|
||||
TFAutoModel = auto_class_update(TFAutoModel)
|
||||
|
||||
|
||||
class TFAutoModelForPreTraining(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
|
||||
|
||||
TFAutoModelForPreTraining = auto_class_factory(
|
||||
"TFAutoModelForPreTraining", TF_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
|
||||
)
|
||||
|
||||
# Private on purpose, the public class will add the deprecation warnings.
|
||||
_TFAutoModelWithLMHead = auto_class_factory(
|
||||
"TFAutoModelWithLMHead", TF_MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling"
|
||||
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
|
||||
|
||||
|
||||
_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
|
||||
|
||||
|
||||
class TFAutoModelForCausalLM(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
||||
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
|
||||
|
||||
|
||||
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForSeq2SeqLM = auto_class_update(
|
||||
TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
|
||||
)
|
||||
|
||||
TFAutoModelForCausalLM = auto_class_factory(
|
||||
"TFAutoModelForCausalLM", TF_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
|
||||
|
||||
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForSequenceClassification = auto_class_update(
|
||||
TFAutoModelForSequenceClassification, head_doc="sequence classification"
|
||||
)
|
||||
|
||||
TFAutoModelForMaskedLM = auto_class_factory(
|
||||
"TFAutoModelForMaskedLM", TF_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||
|
||||
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
|
||||
|
||||
|
||||
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForTokenClassification = auto_class_update(
|
||||
TFAutoModelForTokenClassification, head_doc="token classification"
|
||||
)
|
||||
|
||||
TFAutoModelForSeq2SeqLM = auto_class_factory(
|
||||
"TFAutoModelForSeq2SeqLM",
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
head_doc="sequence-to-sequence language modeling",
|
||||
checkpoint_for_example="t5-base",
|
||||
)
|
||||
|
||||
TFAutoModelForSequenceClassification = auto_class_factory(
|
||||
"TFAutoModelForSequenceClassification",
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
head_doc="sequence classification",
|
||||
)
|
||||
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
|
||||
|
||||
TFAutoModelForQuestionAnswering = auto_class_factory(
|
||||
"TFAutoModelForQuestionAnswering", TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
|
||||
)
|
||||
|
||||
TFAutoModelForTokenClassification = auto_class_factory(
|
||||
"TFAutoModelForTokenClassification", TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
|
||||
)
|
||||
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
|
||||
|
||||
TFAutoModelForMultipleChoice = auto_class_factory(
|
||||
"TFAutoModelForMultipleChoice", TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
|
||||
)
|
||||
|
||||
TFAutoModelForNextSentencePrediction = auto_class_factory(
|
||||
"TFAutoModelForNextSentencePrediction",
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
head_doc="next sentence prediction",
|
||||
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForNextSentencePrediction = auto_class_update(
|
||||
TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user