Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)

* Initial commit
* more stash commit
* Yet another stash commit
* yet more stash commit
* Mostly working except for docs / repo consistency
* Stop importing model list from torch file
* Add TF BLIP models to docs
* Add auto classes
* Move get_text_features and get_image_features
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update tests/models/blip/test_modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Use channels_last convolutions in TF (better performance + compatibility)
* Remove _shape function
* Move multi-line statement to one line in PT + TF
* Specify tf.keras.layers instead of importing from it
* Remove test_gradient_checkpointing and empty test_training methods
* move some multi-line statements to one line
* Update docstring for generate
* Remove pruned heads set
* Remove self.seq_len_dim
* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states
* ensure original model follows config in more cases
* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!
* Add training args throughout the models and layers
* make fixup
* Fix docstring for inputs_embeds
* Add docstring for is_decoder
* Add docstrings to text models
* Remove redundant computation
* Add unpack_inputs / keras_serializable
* Add modeling_tf_blip to doctests
* Add config classes for keras serialization
* Changes to allow model porting with pt-to-tf
* Quick fix to decoder head and test tweaks
* Revert an issue with masking the embeddings outputs
* Allow missing keys in some equivalence tests (for unused layers)
* Add tf-pt equivalence tests back in
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* make fixup
* Refactor invert_attention_mask out into tf_utils
* Re-enable cross-tests on the PT side too
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
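As a quick illustration of what the auto-class changes in this PR enable (a minimal usage sketch, not part of the commit; the checkpoint name is only an assumed example and the TF weights may need `from_pt=True`), the newly registered TF BLIP model becomes loadable through the auto classes defined in the file below:

    from transformers import TFAutoModel, TFAutoModelForZeroShotImageClassification

    # "blip" now appears in TF_MODEL_MAPPING_NAMES, so the base auto class resolves to TFBlipModel
    model = TFAutoModel.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)

    # BLIP is also registered next to CLIP for zero-shot image classification
    zero_shot = TFAutoModelForZeroShotImageClassification.from_pretrained(
        "Salesforce/blip-image-captioning-base", from_pt=True
    )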
655 lines | 25 KiB | Python
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class."""


import warnings
from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)

TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        ("dpr", "TFDPRQuestionEncoder"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swin", "TFSwinModel"),
        ("t5", "TFT5Model"),
        ("tapas", "TFTapasModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
        ("vit", "TFViTModel"),
        ("vit_mae", "TFViTMAEModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("whisper", "TFWhisperModel"),
        ("xglm", "TFXGLMModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("xlnet", "TFXLNetModel"),
    ]
)

TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "TFDeiTForMaskedImageModeling"),
        ("swin", "TFSwinForMaskedImageModeling"),
    ]
)

TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image classification mapping
        ("convnext", "TFConvNextForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)

TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)

TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)

TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    ]
)

TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)

TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)

TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    ]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)

TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)

TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
    ]
)

TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TFTapasForQuestionAnswering"),
    ]
)

TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)

TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)

TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    ]
)

TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
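
# Usage sketch (illustrative, based on the documented behaviour of _LazyAutoMapping): each mapping
# above acts like a dict keyed by config classes and only imports the concrete model class when an
# entry is accessed, so importing this module stays cheap. For example:
#
#     from transformers import BertConfig
#
#     TF_MODEL_MAPPING[BertConfig]  # resolves to TFBertModel, imported lazily on first access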

class TFAutoModel(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_MAPPING


TFAutoModel = auto_class_update(TFAutoModel)


class TFAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")


class TFAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")


class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
)


class TFAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification, head_doc="image classification"
)


class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)


class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")


class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)


class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification, head_doc="sequence classification"
)


class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")


class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification, head_doc="token classification"
)


class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")


class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
            " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
            " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
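
# Migration sketch for the deprecation above (illustrative only; the checkpoint names are just the
# usual public examples, not mandated by this module):
#
#     from transformers import TFAutoModelForCausalLM, TFAutoModelForMaskedLM, TFAutoModelForSeq2SeqLM
#
#     causal_lm = TFAutoModelForCausalLM.from_pretrained("gpt2")                # decoder-only models
#     masked_lm = TFAutoModelForMaskedLM.from_pretrained("bert-base-uncased")   # masked language models
#     seq2seq_lm = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base")           # encoder-decoder models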