From 9870093f7b31bf774fe6bdfeed5e08f0d4649b07 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 6 Aug 2021 13:12:30 +0200 Subject: [PATCH] [WIP] Disentangle auto modules from other modeling files (#13023) * Initial work * All auto models * All tf auto models * All flax auto models * Tokenizers * Add feature extractors * Fix typos * Fix other typo * Use the right config * Remove old mapping names and update logic in AutoTokenizer * Update check_table * Fix copies and check_repo script * Fix last test * Add back name * clean up * Update template * Update template * Forgot a ) * Use alternative to fixup * Fix TF model template * Address review comments * Address review comments * Style --- .github/workflows/model-templates.yml | 2 +- Makefile | 1 - src/transformers/__init__.py | 8 +- src/transformers/modelcard.py | 6 +- src/transformers/models/__init__.py | 5 + src/transformers/models/auto/auto_factory.py | 84 +- .../models/auto/configuration_auto.py | 472 +++---- .../models/auto/feature_extraction_auto.py | 54 +- src/transformers/models/auto/modeling_auto.py | 1079 ++++++----------- .../models/auto/modeling_flax_auto.py | 238 ++-- .../models/auto/modeling_tf_auto.py | 656 ++++------ .../models/auto/tokenization_auto.py | 457 +++---- src/transformers/models/mbart/__init__.py | 4 - src/transformers/models/mbart50/__init__.py | 42 + .../tokenization_mbart50.py | 0 .../tokenization_mbart50_fast.py | 0 src/transformers/tokenization_utils_base.py | 21 +- src/transformers/trainer.py | 2 +- .../utils/dummy_tokenizers_objects.py | 4 +- .../utils/modeling_auto_mapping.py | 374 ------ ...tf_{{cookiecutter.lowercase_modelname}}.py | 6 +- ...ce_{{cookiecutter.lowercase_modelname}}.py | 102 +- tests/test_pipelines_translation.py | 3 +- utils/check_repo.py | 8 +- utils/check_table.py | 9 +- utils/class_mapping_update.py | 106 -- 26 files changed, 1338 insertions(+), 2405 deletions(-) create mode 100644 
src/transformers/models/mbart50/__init__.py rename src/transformers/models/{mbart => mbart50}/tokenization_mbart50.py (100%) rename src/transformers/models/{mbart => mbart50}/tokenization_mbart50_fast.py (100%) delete mode 100644 src/transformers/utils/modeling_auto_mapping.py delete mode 100644 utils/class_mapping_update.py diff --git a/.github/workflows/model-templates.yml b/.github/workflows/model-templates.yml index ab0f7a9aade..83e5b40de4a 100644 --- a/.github/workflows/model-templates.yml +++ b/.github/workflows/model-templates.yml @@ -59,7 +59,7 @@ jobs: - name: Run style changes run: | git fetch origin master:master - make fixup + make style && make quality - name: Failure short reports if: ${{ always() }} diff --git a/Makefile b/Makefile index 4ea50b9d486..11e96f84c99 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,6 @@ deps_table_check_updated: # autogenerating code autogenerate_code: deps_table_update - python utils/class_mapping_update.py # Check that source code meets quality standards diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e776393e37a..2d547e8cccf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -213,6 +213,7 @@ _import_structure = { "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], "models.marian": ["MarianConfig"], "models.mbart": ["MBartConfig"], + "models.mbart50": [], "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"], "models.mmbt": ["MMBTConfig"], "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"], @@ -315,7 +316,7 @@ if is_sentencepiece_available(): _import_structure["models.m2m_100"].append("M2M100Tokenizer") _import_structure["models.marian"].append("MarianTokenizer") _import_structure["models.mbart"].append("MBartTokenizer") - _import_structure["models.mbart"].append("MBart50Tokenizer") + 
_import_structure["models.mbart50"].append("MBart50Tokenizer") _import_structure["models.mt5"].append("MT5Tokenizer") _import_structure["models.pegasus"].append("PegasusTokenizer") _import_structure["models.reformer"].append("ReformerTokenizer") @@ -358,7 +359,7 @@ if is_tokenizers_available(): _import_structure["models.longformer"].append("LongformerTokenizerFast") _import_structure["models.lxmert"].append("LxmertTokenizerFast") _import_structure["models.mbart"].append("MBartTokenizerFast") - _import_structure["models.mbart"].append("MBart50TokenizerFast") + _import_structure["models.mbart50"].append("MBart50TokenizerFast") _import_structure["models.mobilebert"].append("MobileBertTokenizerFast") _import_structure["models.mpnet"].append("MPNetTokenizerFast") _import_structure["models.mt5"].append("MT5TokenizerFast") @@ -2021,7 +2022,8 @@ if TYPE_CHECKING: from .models.led import LEDTokenizerFast from .models.longformer import LongformerTokenizerFast from .models.lxmert import LxmertTokenizerFast - from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast + from .models.mbart import MBartTokenizerFast + from .models.mbart50 import MBart50TokenizerFast from .models.mobilebert import MobileBertTokenizerFast from .models.mpnet import MPNetTokenizerFast from .models.mt5 import MT5TokenizerFast diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index bb1b3b840b1..d9a0f6803d8 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -41,9 +41,7 @@ from .file_utils import ( is_tokenizers_available, is_torch_available, ) -from .training_args import ParallelMode -from .utils import logging -from .utils.modeling_auto_mapping import ( +from .models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES, @@ -54,6 +52,8 @@ from .utils.modeling_auto_mapping import ( MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES, 
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, ) +from .training_args import ParallelMode +from .utils import logging TASK_MAPPING = { diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0d01a7680f3..a0c4524355c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -37,6 +37,7 @@ from . import ( cpm, ctrl, deberta, + deberta_v2, deit, detr, dialogpt, @@ -50,6 +51,8 @@ from . import ( gpt2, gpt_neo, herbert, + hubert, + ibert, layoutlm, led, longformer, @@ -58,6 +61,7 @@ from . import ( m2m_100, marian, mbart, + mbart50, megatron_bert, mmbt, mobilebert, @@ -82,6 +86,7 @@ from . import ( vit, wav2vec2, xlm, + xlm_prophetnet, xlm_roberta, xlnet, ) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 9bb5eaf6ef7..2214a61e77b 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -13,11 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Factory function to build auto-model classes.""" +import importlib +from collections import OrderedDict from ...configuration_utils import PretrainedConfig from ...file_utils import copy_func from ...utils import logging -from .configuration_auto import AutoConfig, replace_list_option_in_docstrings +from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings logger = logging.get_logger(__name__) @@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc="" from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name) from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) from_config.__doc__ = from_config_docstring - from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config) + from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) cls.from_config = classmethod(from_config) if name.startswith("TF"): @@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc="" shortcut = checkpoint_for_example.split("/")[-1].split("-")[0] from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) from_pretrained.__doc__ = from_pretrained_docstring - from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained) + from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained) cls.from_pretrained = classmethod(from_pretrained) return cls @@ -445,3 +447,79 @@ def get_values(model_mapping): result.append(model) return result + + +def getattribute_from_module(module, attr): + if attr is None: + return None + if isinstance(attr, tuple): + return tuple(getattribute_from_module(module, a) for a in attr) + if hasattr(module, attr): + return getattr(module, attr) + # Some of the mappings have entries model_type 
-> object of another model type. In that case we try to grab the + # object at the top level. + transformers_module = importlib.import_module("transformers") + return getattribute_from_module(transformers_module, attr) + + +class _LazyAutoMapping(OrderedDict): + """ + " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. + + Args: + + - config_mapping: The map model type to config class + - model_mapping: The map model type to model (or tokenizer) class + """ + + def __init__(self, config_mapping, model_mapping): + self._config_mapping = config_mapping + self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} + self._model_mapping = model_mapping + self._modules = {} + + def __getitem__(self, key): + model_type = self._reverse_config_mapping[key.__name__] + if model_type not in self._model_mapping: + raise KeyError(key) + model_name = self._model_mapping[model_type] + return self._load_attr_from_module(model_type, model_name) + + def _load_attr_from_module(self, model_type, attr): + module_name = model_type_to_module_name(model_type) + if module_name not in self._modules: + self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") + return getattribute_from_module(self._modules[module_name], attr) + + def keys(self): + return [ + self._load_attr_from_module(key, name) + for key, name in self._config_mapping.items() + if key in self._model_mapping.keys() + ] + + def values(self): + return [ + self._load_attr_from_module(key, name) + for key, name in self._model_mapping.items() + if key in self._config_mapping.keys() + ] + + def items(self): + return [ + ( + self._load_attr_from_module(key, self._config_mapping[key]), + self._load_attr_from_module(key, self._model_mapping[key]), + ) + for key in self._model_mapping.keys() + if key in self._config_mapping.keys() + ] + + def __iter__(self): + return iter(self._mapping.keys()) + + def __contains__(self, 
item): + if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: + return False + model_type = self._reverse_config_mapping[item.__name__] + return model_type in self._model_mapping diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index aa7ccaa1632..69750d0c99c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -13,215 +13,140 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Auto Config class. """ - +import importlib import re +import warnings from collections import OrderedDict +from typing import List, Union from ...configuration_utils import PretrainedConfig -from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig -from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig -from ..beit.configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig -from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig -from ..bert_generation.configuration_bert_generation import BertGenerationConfig -from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig -from ..bigbird_pegasus.configuration_bigbird_pegasus import ( - BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, - BigBirdPegasusConfig, -) -from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig -from ..blenderbot_small.configuration_blenderbot_small import ( - BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, - BlenderbotSmallConfig, -) -from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig -from ..canine.configuration_canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig -from ..clip.configuration_clip import 
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig -from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig -from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig -from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig -from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config -from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig -from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig -from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig -from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig -from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig -from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig -from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig -from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig -from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig -from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config -from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig -from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig -from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig -from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig -from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig -from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig -from ..luke.configuration_luke import 
LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig -from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig -from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config -from ..marian.configuration_marian import MarianConfig -from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig -from ..megatron_bert.configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig -from ..mobilebert.configuration_mobilebert import MobileBertConfig -from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig -from ..mt5.configuration_mt5 import MT5Config -from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig -from ..pegasus.configuration_pegasus import PegasusConfig -from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig -from ..rag.configuration_rag import RagConfig -from ..reformer.configuration_reformer import ReformerConfig -from ..rembert.configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig -from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig -from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig -from ..roformer.configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig -from ..speech_to_text.configuration_speech_to_text import ( - SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, - Speech2TextConfig, -) -from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig -from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config -from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig -from ..transfo_xl.configuration_transfo_xl import 
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig -from ..visual_bert.configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig -from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig -from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config -from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig -from ..xlm_prophetnet.configuration_xlm_prophetnet import ( - XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLMProphetNetConfig, -) -from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig -from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig -ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( - (key, value) - for pretrained_map in [ - # Add archive maps here - BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, - REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, - ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, - CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, - BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, - DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, - LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, - DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, - GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, - BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, - MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, - VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, - WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, - M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, - CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - LED_PRETRAINED_CONFIG_ARCHIVE_MAP, - BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, - BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - BART_PRETRAINED_CONFIG_ARCHIVE_MAP, - BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, - MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, - OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, - TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, - 
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, - CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, - ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - T5_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, - ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, - LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, - RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, - LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, - DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, - DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, - SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, - IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - ] - for key, value, in pretrained_map.items() -) - - -CONFIG_MAPPING = OrderedDict( +CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here - ("beit", BeitConfig), - ("rembert", RemBertConfig), - ("visual_bert", VisualBertConfig), - ("canine", CanineConfig), - ("roformer", RoFormerConfig), - ("clip", CLIPConfig), - ("bigbird_pegasus", BigBirdPegasusConfig), - ("deit", DeiTConfig), - ("luke", LukeConfig), - ("detr", DetrConfig), - ("gpt_neo", GPTNeoConfig), - ("big_bird", BigBirdConfig), - ("speech_to_text", Speech2TextConfig), - ("vit", ViTConfig), - ("wav2vec2", Wav2Vec2Config), - ("m2m_100", M2M100Config), - ("convbert", ConvBertConfig), - ("led", LEDConfig), - ("blenderbot-small", BlenderbotSmallConfig), - ("retribert", RetriBertConfig), - ("ibert", IBertConfig), - ("mt5", MT5Config), - ("t5", T5Config), - ("mobilebert", 
MobileBertConfig), - ("distilbert", DistilBertConfig), - ("albert", AlbertConfig), - ("bert-generation", BertGenerationConfig), - ("camembert", CamembertConfig), - ("xlm-roberta", XLMRobertaConfig), - ("pegasus", PegasusConfig), - ("marian", MarianConfig), - ("mbart", MBartConfig), - ("megatron-bert", MegatronBertConfig), - ("mpnet", MPNetConfig), - ("bart", BartConfig), - ("blenderbot", BlenderbotConfig), - ("reformer", ReformerConfig), - ("longformer", LongformerConfig), - ("roberta", RobertaConfig), - ("deberta-v2", DebertaV2Config), - ("deberta", DebertaConfig), - ("flaubert", FlaubertConfig), - ("fsmt", FSMTConfig), - ("squeezebert", SqueezeBertConfig), - ("hubert", HubertConfig), - ("bert", BertConfig), - ("openai-gpt", OpenAIGPTConfig), - ("gpt2", GPT2Config), - ("transfo-xl", TransfoXLConfig), - ("xlnet", XLNetConfig), - ("xlm-prophetnet", XLMProphetNetConfig), - ("prophetnet", ProphetNetConfig), - ("xlm", XLMConfig), - ("ctrl", CTRLConfig), - ("electra", ElectraConfig), - ("encoder-decoder", EncoderDecoderConfig), - ("funnel", FunnelConfig), - ("lxmert", LxmertConfig), - ("dpr", DPRConfig), - ("layoutlm", LayoutLMConfig), - ("rag", RagConfig), - ("tapas", TapasConfig), + ("beit", "BeitConfig"), + ("rembert", "RemBertConfig"), + ("visual_bert", "VisualBertConfig"), + ("canine", "CanineConfig"), + ("roformer", "RoFormerConfig"), + ("clip", "CLIPConfig"), + ("bigbird_pegasus", "BigBirdPegasusConfig"), + ("deit", "DeiTConfig"), + ("luke", "LukeConfig"), + ("detr", "DetrConfig"), + ("gpt_neo", "GPTNeoConfig"), + ("big_bird", "BigBirdConfig"), + ("speech_to_text", "Speech2TextConfig"), + ("vit", "ViTConfig"), + ("wav2vec2", "Wav2Vec2Config"), + ("m2m_100", "M2M100Config"), + ("convbert", "ConvBertConfig"), + ("led", "LEDConfig"), + ("blenderbot-small", "BlenderbotSmallConfig"), + ("retribert", "RetriBertConfig"), + ("ibert", "IBertConfig"), + ("mt5", "MT5Config"), + ("t5", "T5Config"), + ("mobilebert", "MobileBertConfig"), + ("distilbert", "DistilBertConfig"), + 
("albert", "AlbertConfig"), + ("bert-generation", "BertGenerationConfig"), + ("camembert", "CamembertConfig"), + ("xlm-roberta", "XLMRobertaConfig"), + ("pegasus", "PegasusConfig"), + ("marian", "MarianConfig"), + ("mbart", "MBartConfig"), + ("megatron-bert", "MegatronBertConfig"), + ("mpnet", "MPNetConfig"), + ("bart", "BartConfig"), + ("blenderbot", "BlenderbotConfig"), + ("reformer", "ReformerConfig"), + ("longformer", "LongformerConfig"), + ("roberta", "RobertaConfig"), + ("deberta-v2", "DebertaV2Config"), + ("deberta", "DebertaConfig"), + ("flaubert", "FlaubertConfig"), + ("fsmt", "FSMTConfig"), + ("squeezebert", "SqueezeBertConfig"), + ("hubert", "HubertConfig"), + ("bert", "BertConfig"), + ("openai-gpt", "OpenAIGPTConfig"), + ("gpt2", "GPT2Config"), + ("transfo-xl", "TransfoXLConfig"), + ("xlnet", "XLNetConfig"), + ("xlm-prophetnet", "XLMProphetNetConfig"), + ("prophetnet", "ProphetNetConfig"), + ("xlm", "XLMConfig"), + ("ctrl", "CTRLConfig"), + ("electra", "ElectraConfig"), + ("encoder-decoder", "EncoderDecoderConfig"), + ("funnel", "FunnelConfig"), + ("lxmert", "LxmertConfig"), + ("dpr", "DPRConfig"), + ("layoutlm", "LayoutLMConfig"), + ("rag", "RagConfig"), + ("tapas", "TapasConfig"), + ] +) + +CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( + [ + # Add archive maps here + ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("megatron-bert", 
"MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("squeezebert", 
"SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ] ) @@ -290,14 +215,136 @@ MODEL_NAMES_MAPPING = OrderedDict( ("mpnet", "MPNet"), ("tapas", "TAPAS"), ("hubert", "Hubert"), + ("barthez", "BARThez"), + ("phobert", "PhoBERT"), + ("cpm", "CPM"), + ("bertweet", "Bertweet"), + ("bert-japanese", "BertJapanese"), + ("byt5", "ByT5"), + ("mbart50", "mBART-50"), ] ) +SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict([("openai-gpt", "openai")]) -def _get_class_name(model_class): + +def model_type_to_module_name(key): + """Converts a config key to the corresponding module.""" + # Special treatment + if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME: + return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key] + + return key.replace("-", "_") + + +def config_class_to_model_type(config): + """Converts a config class name to the corresponding model type""" + for key, cls in CONFIG_MAPPING_NAMES.items(): + if cls == config: + return key + return None + + +class _LazyConfigMapping(OrderedDict): + """ + A dictionary that lazily load its values when they are requested. 
+ """ + + def __init__(self, mapping): + self._mapping = mapping + self._modules = {} + + def __getitem__(self, key): + if key not in self._mapping: + raise KeyError(key) + value = self._mapping[key] + module_name = model_type_to_module_name(key) + if module_name not in self._modules: + self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") + return getattr(self._modules[module_name], value) + + def keys(self): + return self._mapping.keys() + + def values(self): + return [self[k] for k in self._mapping.keys()] + + def items(self): + return [(k, self[k]) for k in self._mapping.keys()] + + def __iter__(self): + return iter(self._mapping.keys()) + + def __contains__(self, item): + return item in self._mapping + + +CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES) + + +class _LazyLoadAllMappings(OrderedDict): + """ + A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values, + etc.) + + Args: + mapping: The mapping to load. + """ + + def __init__(self, mapping): + self._mapping = mapping + self._initialized = False + self._data = {} + + def _initialize(self): + if self._initialized: + return + warnings.warn( + "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " + "It does not contain all available model checkpoints, far from it. 
Checkout hf.co/models for that.", + FutureWarning, + ) + + for model_type, map_name in self._mapping.items(): + module_name = model_type_to_module_name(model_type) + module = importlib.import_module(f".{module_name}", "transformers.models") + mapping = getattr(module, map_name) + self._data.update(mapping) + + self._initialized = True + + def __getitem__(self, key): + self._initialize() + return self._data[key] + + def keys(self): + self._initialize() + return self._data.keys() + + def values(self): + self._initialize() + return self._data.values() + + def items(self): + self._initialize() + return self._data.keys() + + def __iter__(self): + self._initialize() + return iter(self._data) + + def __contains__(self, item): + self._initialize() + return item in self._data + + +ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES) + + +def _get_class_name(model_class: Union[str, List[str]]): if isinstance(model_class, (list, tuple)): - return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class]) - return f":class:`~transformers.{model_class.__name__}`" + return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None]) + return f":class:`~transformers.{model_class}`" def _list_model_options(indent, config_to_class=None, use_model_types=True): @@ -306,23 +353,26 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): if use_model_types: if config_to_class is None: model_type_to_name = { - model_type: f":class:`~transformers.{config.__name__}`" - for model_type, config in CONFIG_MAPPING.items() + model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() } else: model_type_to_name = { - model_type: _get_class_name(config_to_class[config]) - for model_type, config in CONFIG_MAPPING.items() - if config in config_to_class + model_type: _get_class_name(model_class) + for model_type, model_class in config_to_class.items() + if 
model_type in MODEL_NAMES_MAPPING } lines = [ f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)" for model_type in sorted(model_type_to_name.keys()) ] else: - config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()} + config_to_name = { + CONFIG_MAPPING_NAMES[config]: _get_class_name(clas) + for config, clas in config_to_class.items() + if config in CONFIG_MAPPING_NAMES + } config_to_model_name = { - config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items() + config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() } lines = [ f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 6d853a131a1..39ba15f5ac8 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -13,36 +13,43 @@ # See the License for the specific language governing permissions and # limitations under the License. """ AutoFeatureExtractor class. """ - +import importlib import os from collections import OrderedDict -from transformers import BeitFeatureExtractor, DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor - -from ... 
import BeitConfig, DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config -from ...feature_extraction_utils import FeatureExtractionMixin - # Build the list of all feature extractors +from ...configuration_utils import PretrainedConfig +from ...feature_extraction_utils import FeatureExtractionMixin from ...file_utils import FEATURE_EXTRACTOR_NAME -from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor -from .configuration_auto import AutoConfig, replace_list_option_in_docstrings - - -FEATURE_EXTRACTOR_MAPPING = OrderedDict( - [ - (BeitConfig, BeitFeatureExtractor), - (DeiTConfig, DeiTFeatureExtractor), - (Speech2TextConfig, Speech2TextFeatureExtractor), - (ViTConfig, ViTFeatureExtractor), - (Wav2Vec2Config, Wav2Vec2FeatureExtractor), - ] +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + config_class_to_model_type, + replace_list_option_in_docstrings, ) +FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( + [ + ("beit", "BeitFeatureExtractor"), + ("deit", "DeiTFeatureExtractor"), + ("speech_to_text", "Speech2TextFeatureExtractor"), + ("vit", "ViTFeatureExtractor"), + ("wav2vec2", "Wav2Vec2FeatureExtractor"), + ] +) + +FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES) + + def feature_extractor_class_from_name(class_name: str): - for c in FEATURE_EXTRACTOR_MAPPING.values(): - if c is not None and c.__name__ == class_name: - return c + for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items(): + if class_name in extractors: + break + + module = importlib.import_module(f".{module_name}", "transformers.models") + return getattr(module, class_name) class AutoFeatureExtractor: @@ -60,7 +67,7 @@ class AutoFeatureExtractor: ) @classmethod - @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING) + @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES) def from_pretrained(cls, 
pretrained_model_name_or_path, **kwargs): r""" Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary. @@ -142,7 +149,8 @@ class AutoFeatureExtractor: kwargs["_from_auto"] = True config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) - if type(config) in FEATURE_EXTRACTOR_MAPPING.keys(): + model_type = config_class_to_model_type(type(config).__name__) + if model_type in FEATURE_EXTRACTOR_MAPPING_NAMES: return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs) elif "feature_extractor_type" in config_dict: feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"]) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 92b4d132568..ab254c3fa89 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -14,793 +14,454 @@ # limitations under the License. """ Auto Model class.
""" - import warnings from collections import OrderedDict from ...utils import logging - -# Add modeling imports here -from ..albert.modeling_albert import ( - AlbertForMaskedLM, - AlbertForMultipleChoice, - AlbertForPreTraining, - AlbertForQuestionAnswering, - AlbertForSequenceClassification, - AlbertForTokenClassification, - AlbertModel, -) -from ..bart.modeling_bart import ( - BartForCausalLM, - BartForConditionalGeneration, - BartForQuestionAnswering, - BartForSequenceClassification, - BartModel, -) -from ..beit.modeling_beit import BeitForImageClassification, BeitModel -from ..bert.modeling_bert import ( - BertForMaskedLM, - BertForMultipleChoice, - BertForNextSentencePrediction, - BertForPreTraining, - BertForQuestionAnswering, - BertForSequenceClassification, - BertForTokenClassification, - BertLMHeadModel, - BertModel, -) -from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder -from ..big_bird.modeling_big_bird import ( - BigBirdForCausalLM, - BigBirdForMaskedLM, - BigBirdForMultipleChoice, - BigBirdForPreTraining, - BigBirdForQuestionAnswering, - BigBirdForSequenceClassification, - BigBirdForTokenClassification, - BigBirdModel, -) -from ..bigbird_pegasus.modeling_bigbird_pegasus import ( - BigBirdPegasusForCausalLM, - BigBirdPegasusForConditionalGeneration, - BigBirdPegasusForQuestionAnswering, - BigBirdPegasusForSequenceClassification, - BigBirdPegasusModel, -) -from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel -from ..blenderbot_small.modeling_blenderbot_small import ( - BlenderbotSmallForCausalLM, - BlenderbotSmallForConditionalGeneration, - BlenderbotSmallModel, -) -from ..camembert.modeling_camembert import ( - CamembertForCausalLM, - CamembertForMaskedLM, - CamembertForMultipleChoice, - CamembertForQuestionAnswering, - CamembertForSequenceClassification, - CamembertForTokenClassification, - CamembertModel, -) -from 
..canine.modeling_canine import ( - CanineForMultipleChoice, - CanineForQuestionAnswering, - CanineForSequenceClassification, - CanineForTokenClassification, - CanineModel, -) -from ..clip.modeling_clip import CLIPModel -from ..convbert.modeling_convbert import ( - ConvBertForMaskedLM, - ConvBertForMultipleChoice, - ConvBertForQuestionAnswering, - ConvBertForSequenceClassification, - ConvBertForTokenClassification, - ConvBertModel, -) -from ..ctrl.modeling_ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel -from ..deberta.modeling_deberta import ( - DebertaForMaskedLM, - DebertaForQuestionAnswering, - DebertaForSequenceClassification, - DebertaForTokenClassification, - DebertaModel, -) -from ..deberta_v2.modeling_deberta_v2 import ( - DebertaV2ForMaskedLM, - DebertaV2ForQuestionAnswering, - DebertaV2ForSequenceClassification, - DebertaV2ForTokenClassification, - DebertaV2Model, -) -from ..deit.modeling_deit import DeiTForImageClassification, DeiTForImageClassificationWithTeacher, DeiTModel -from ..detr.modeling_detr import DetrForObjectDetection, DetrModel -from ..distilbert.modeling_distilbert import ( - DistilBertForMaskedLM, - DistilBertForMultipleChoice, - DistilBertForQuestionAnswering, - DistilBertForSequenceClassification, - DistilBertForTokenClassification, - DistilBertModel, -) -from ..dpr.modeling_dpr import DPRQuestionEncoder -from ..electra.modeling_electra import ( - ElectraForMaskedLM, - ElectraForMultipleChoice, - ElectraForPreTraining, - ElectraForQuestionAnswering, - ElectraForSequenceClassification, - ElectraForTokenClassification, - ElectraModel, -) -from ..encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel -from ..flaubert.modeling_flaubert import ( - FlaubertForMultipleChoice, - FlaubertForQuestionAnsweringSimple, - FlaubertForSequenceClassification, - FlaubertForTokenClassification, - FlaubertModel, - FlaubertWithLMHeadModel, -) -from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel -from 
..funnel.modeling_funnel import ( - FunnelBaseModel, - FunnelForMaskedLM, - FunnelForMultipleChoice, - FunnelForPreTraining, - FunnelForQuestionAnswering, - FunnelForSequenceClassification, - FunnelForTokenClassification, - FunnelModel, -) -from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model -from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel -from ..hubert.modeling_hubert import HubertModel -from ..ibert.modeling_ibert import ( - IBertForMaskedLM, - IBertForMultipleChoice, - IBertForQuestionAnswering, - IBertForSequenceClassification, - IBertForTokenClassification, - IBertModel, -) -from ..layoutlm.modeling_layoutlm import ( - LayoutLMForMaskedLM, - LayoutLMForSequenceClassification, - LayoutLMForTokenClassification, - LayoutLMModel, -) -from ..led.modeling_led import ( - LEDForConditionalGeneration, - LEDForQuestionAnswering, - LEDForSequenceClassification, - LEDModel, -) -from ..longformer.modeling_longformer import ( - LongformerForMaskedLM, - LongformerForMultipleChoice, - LongformerForQuestionAnswering, - LongformerForSequenceClassification, - LongformerForTokenClassification, - LongformerModel, -) -from ..luke.modeling_luke import LukeModel -from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel -from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model -from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel -from ..mbart.modeling_mbart import ( - MBartForCausalLM, - MBartForConditionalGeneration, - MBartForQuestionAnswering, - MBartForSequenceClassification, - MBartModel, -) -from ..megatron_bert.modeling_megatron_bert import ( - MegatronBertForCausalLM, - MegatronBertForMaskedLM, - MegatronBertForMultipleChoice, - MegatronBertForNextSentencePrediction, - MegatronBertForPreTraining, - MegatronBertForQuestionAnswering, - MegatronBertForSequenceClassification, - 
MegatronBertForTokenClassification, - MegatronBertModel, -) -from ..mobilebert.modeling_mobilebert import ( - MobileBertForMaskedLM, - MobileBertForMultipleChoice, - MobileBertForNextSentencePrediction, - MobileBertForPreTraining, - MobileBertForQuestionAnswering, - MobileBertForSequenceClassification, - MobileBertForTokenClassification, - MobileBertModel, -) -from ..mpnet.modeling_mpnet import ( - MPNetForMaskedLM, - MPNetForMultipleChoice, - MPNetForQuestionAnswering, - MPNetForSequenceClassification, - MPNetForTokenClassification, - MPNetModel, -) -from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model -from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel -from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel -from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel -from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function - RagModel, - RagSequenceForGeneration, - RagTokenForGeneration, -) -from ..reformer.modeling_reformer import ( - ReformerForMaskedLM, - ReformerForQuestionAnswering, - ReformerForSequenceClassification, - ReformerModel, - ReformerModelWithLMHead, -) -from ..rembert.modeling_rembert import ( - RemBertForCausalLM, - RemBertForMaskedLM, - RemBertForMultipleChoice, - RemBertForQuestionAnswering, - RemBertForSequenceClassification, - RemBertForTokenClassification, - RemBertModel, -) -from ..retribert.modeling_retribert import RetriBertModel -from ..roberta.modeling_roberta import ( - RobertaForCausalLM, - RobertaForMaskedLM, - RobertaForMultipleChoice, - RobertaForQuestionAnswering, - RobertaForSequenceClassification, - RobertaForTokenClassification, - RobertaModel, -) -from ..roformer.modeling_roformer import ( - RoFormerForCausalLM, - RoFormerForMaskedLM, - RoFormerForMultipleChoice, - RoFormerForQuestionAnswering, - 
RoFormerForSequenceClassification, - RoFormerForTokenClassification, - RoFormerModel, -) -from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel -from ..squeezebert.modeling_squeezebert import ( - SqueezeBertForMaskedLM, - SqueezeBertForMultipleChoice, - SqueezeBertForQuestionAnswering, - SqueezeBertForSequenceClassification, - SqueezeBertForTokenClassification, - SqueezeBertModel, -) -from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model -from ..tapas.modeling_tapas import ( - TapasForMaskedLM, - TapasForQuestionAnswering, - TapasForSequenceClassification, - TapasModel, -) -from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel -from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel -from ..vit.modeling_vit import ViTForImageClassification, ViTModel -from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2Model -from ..xlm.modeling_xlm import ( - XLMForMultipleChoice, - XLMForQuestionAnsweringSimple, - XLMForSequenceClassification, - XLMForTokenClassification, - XLMModel, - XLMWithLMHeadModel, -) -from ..xlm_prophetnet.modeling_xlm_prophetnet import ( - XLMProphetNetForCausalLM, - XLMProphetNetForConditionalGeneration, - XLMProphetNetModel, -) -from ..xlm_roberta.modeling_xlm_roberta import ( - XLMRobertaForCausalLM, - XLMRobertaForMaskedLM, - XLMRobertaForMultipleChoice, - XLMRobertaForQuestionAnswering, - XLMRobertaForSequenceClassification, - XLMRobertaForTokenClassification, - XLMRobertaModel, -) -from ..xlnet.modeling_xlnet import ( - XLNetForMultipleChoice, - XLNetForQuestionAnsweringSimple, - XLNetForSequenceClassification, - XLNetForTokenClassification, - XLNetLMHeadModel, - XLNetModel, -) -from .auto_factory import _BaseAutoModelClass, auto_class_update -from .configuration_auto import ( - AlbertConfig, - BartConfig, - BeitConfig, - BertConfig, - 
BertGenerationConfig, - BigBirdConfig, - BigBirdPegasusConfig, - BlenderbotConfig, - BlenderbotSmallConfig, - CamembertConfig, - CanineConfig, - CLIPConfig, - ConvBertConfig, - CTRLConfig, - DebertaConfig, - DebertaV2Config, - DeiTConfig, - DetrConfig, - DistilBertConfig, - DPRConfig, - ElectraConfig, - EncoderDecoderConfig, - FlaubertConfig, - FSMTConfig, - FunnelConfig, - GPT2Config, - GPTNeoConfig, - HubertConfig, - IBertConfig, - LayoutLMConfig, - LEDConfig, - LongformerConfig, - LukeConfig, - LxmertConfig, - M2M100Config, - MarianConfig, - MBartConfig, - MegatronBertConfig, - MobileBertConfig, - MPNetConfig, - MT5Config, - OpenAIGPTConfig, - PegasusConfig, - ProphetNetConfig, - ReformerConfig, - RemBertConfig, - RetriBertConfig, - RobertaConfig, - RoFormerConfig, - Speech2TextConfig, - SqueezeBertConfig, - T5Config, - TapasConfig, - TransfoXLConfig, - VisualBertConfig, - ViTConfig, - Wav2Vec2Config, - XLMConfig, - XLMProphetNetConfig, - XLMRobertaConfig, - XLNetConfig, -) +from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .configuration_auto import CONFIG_MAPPING_NAMES logger = logging.get_logger(__name__) -MODEL_MAPPING = OrderedDict( +MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping - (BeitConfig, BeitModel), - (RemBertConfig, RemBertModel), - (VisualBertConfig, VisualBertModel), - (CanineConfig, CanineModel), - (RoFormerConfig, RoFormerModel), - (CLIPConfig, CLIPModel), - (BigBirdPegasusConfig, BigBirdPegasusModel), - (DeiTConfig, DeiTModel), - (LukeConfig, LukeModel), - (DetrConfig, DetrModel), - (GPTNeoConfig, GPTNeoModel), - (BigBirdConfig, BigBirdModel), - (Speech2TextConfig, Speech2TextModel), - (ViTConfig, ViTModel), - (Wav2Vec2Config, Wav2Vec2Model), - (HubertConfig, HubertModel), - (M2M100Config, M2M100Model), - (ConvBertConfig, ConvBertModel), - (LEDConfig, LEDModel), - (BlenderbotSmallConfig, BlenderbotSmallModel), - (RetriBertConfig, RetriBertModel), - (MT5Config, MT5Model), - (T5Config, T5Model), - 
(PegasusConfig, PegasusModel), - (MarianConfig, MarianMTModel), - (MBartConfig, MBartModel), - (BlenderbotConfig, BlenderbotModel), - (DistilBertConfig, DistilBertModel), - (AlbertConfig, AlbertModel), - (CamembertConfig, CamembertModel), - (XLMRobertaConfig, XLMRobertaModel), - (BartConfig, BartModel), - (LongformerConfig, LongformerModel), - (RobertaConfig, RobertaModel), - (LayoutLMConfig, LayoutLMModel), - (SqueezeBertConfig, SqueezeBertModel), - (BertConfig, BertModel), - (OpenAIGPTConfig, OpenAIGPTModel), - (GPT2Config, GPT2Model), - (MegatronBertConfig, MegatronBertModel), - (MobileBertConfig, MobileBertModel), - (TransfoXLConfig, TransfoXLModel), - (XLNetConfig, XLNetModel), - (FlaubertConfig, FlaubertModel), - (FSMTConfig, FSMTModel), - (XLMConfig, XLMModel), - (CTRLConfig, CTRLModel), - (ElectraConfig, ElectraModel), - (ReformerConfig, ReformerModel), - (FunnelConfig, (FunnelModel, FunnelBaseModel)), - (LxmertConfig, LxmertModel), - (BertGenerationConfig, BertGenerationEncoder), - (DebertaConfig, DebertaModel), - (DebertaV2Config, DebertaV2Model), - (DPRConfig, DPRQuestionEncoder), - (XLMProphetNetConfig, XLMProphetNetModel), - (ProphetNetConfig, ProphetNetModel), - (MPNetConfig, MPNetModel), - (TapasConfig, TapasModel), - (MarianConfig, MarianModel), - (IBertConfig, IBertModel), + ("beit", "BeitModel"), + ("rembert", "RemBertModel"), + ("visual_bert", "VisualBertModel"), + ("canine", "CanineModel"), + ("roformer", "RoFormerModel"), + ("clip", "CLIPModel"), + ("bigbird_pegasus", "BigBirdPegasusModel"), + ("deit", "DeiTModel"), + ("luke", "LukeModel"), + ("detr", "DetrModel"), + ("gpt_neo", "GPTNeoModel"), + ("big_bird", "BigBirdModel"), + ("speech_to_text", "Speech2TextModel"), + ("vit", "ViTModel"), + ("wav2vec2", "Wav2Vec2Model"), + ("hubert", "HubertModel"), + ("m2m_100", "M2M100Model"), + ("convbert", "ConvBertModel"), + ("led", "LEDModel"), + ("blenderbot-small", "BlenderbotSmallModel"), + ("retribert", "RetriBertModel"), + ("mt5", "MT5Model"), + 
("t5", "T5Model"), + ("pegasus", "PegasusModel"), + ("marian", "MarianModel"), + ("mbart", "MBartModel"), + ("blenderbot", "BlenderbotModel"), + ("distilbert", "DistilBertModel"), + ("albert", "AlbertModel"), + ("camembert", "CamembertModel"), + ("xlm-roberta", "XLMRobertaModel"), + ("bart", "BartModel"), + ("longformer", "LongformerModel"), + ("roberta", "RobertaModel"), + ("layoutlm", "LayoutLMModel"), + ("squeezebert", "SqueezeBertModel"), + ("bert", "BertModel"), + ("openai-gpt", "OpenAIGPTModel"), + ("gpt2", "GPT2Model"), + ("megatron-bert", "MegatronBertModel"), + ("mobilebert", "MobileBertModel"), + ("transfo-xl", "TransfoXLModel"), + ("xlnet", "XLNetModel"), + ("flaubert", "FlaubertModel"), + ("fsmt", "FSMTModel"), + ("xlm", "XLMModel"), + ("ctrl", "CTRLModel"), + ("electra", "ElectraModel"), + ("reformer", "ReformerModel"), + ("funnel", ("FunnelModel", "FunnelBaseModel")), + ("lxmert", "LxmertModel"), + ("bert-generation", "BertGenerationEncoder"), + ("deberta", "DebertaModel"), + ("deberta-v2", "DebertaV2Model"), + ("dpr", "DPRQuestionEncoder"), + ("xlm-prophetnet", "XLMProphetNetModel"), + ("prophetnet", "ProphetNetModel"), + ("mpnet", "MPNetModel"), + ("tapas", "TapasModel"), + ("ibert", "IBertModel"), ] ) -MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( +MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping - (VisualBertConfig, VisualBertForPreTraining), - (LayoutLMConfig, LayoutLMForMaskedLM), - (RetriBertConfig, RetriBertModel), - (T5Config, T5ForConditionalGeneration), - (DistilBertConfig, DistilBertForMaskedLM), - (AlbertConfig, AlbertForPreTraining), - (CamembertConfig, CamembertForMaskedLM), - (XLMRobertaConfig, XLMRobertaForMaskedLM), - (BartConfig, BartForConditionalGeneration), - (FSMTConfig, FSMTForConditionalGeneration), - (LongformerConfig, LongformerForMaskedLM), - (RobertaConfig, RobertaForMaskedLM), - (SqueezeBertConfig, SqueezeBertForMaskedLM), - (BertConfig, BertForPreTraining), - (BigBirdConfig, 
BigBirdForPreTraining), - (OpenAIGPTConfig, OpenAIGPTLMHeadModel), - (GPT2Config, GPT2LMHeadModel), - (MegatronBertConfig, MegatronBertForPreTraining), - (MobileBertConfig, MobileBertForPreTraining), - (TransfoXLConfig, TransfoXLLMHeadModel), - (XLNetConfig, XLNetLMHeadModel), - (FlaubertConfig, FlaubertWithLMHeadModel), - (XLMConfig, XLMWithLMHeadModel), - (CTRLConfig, CTRLLMHeadModel), - (ElectraConfig, ElectraForPreTraining), - (LxmertConfig, LxmertForPreTraining), - (FunnelConfig, FunnelForPreTraining), - (MPNetConfig, MPNetForMaskedLM), - (TapasConfig, TapasForMaskedLM), - (IBertConfig, IBertForMaskedLM), - (DebertaConfig, DebertaForMaskedLM), - (DebertaV2Config, DebertaV2ForMaskedLM), - (Wav2Vec2Config, Wav2Vec2ForPreTraining), + ("visual_bert", "VisualBertForPreTraining"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("retribert", "RetriBertModel"), + ("t5", "T5ForConditionalGeneration"), + ("distilbert", "DistilBertForMaskedLM"), + ("albert", "AlbertForPreTraining"), + ("camembert", "CamembertForMaskedLM"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("fsmt", "FSMTForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("bert", "BertForPreTraining"), + ("big_bird", "BigBirdForPreTraining"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("megatron-bert", "MegatronBertForPreTraining"), + ("mobilebert", "MobileBertForPreTraining"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("xlnet", "XLNetLMHeadModel"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("ctrl", "CTRLLMHeadModel"), + ("electra", "ElectraForPreTraining"), + ("lxmert", "LxmertForPreTraining"), + ("funnel", "FunnelForPreTraining"), + ("mpnet", "MPNetForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("ibert", "IBertForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", 
"DebertaV2ForMaskedLM"), + ("wav2vec2", "Wav2Vec2ForPreTraining"), ] ) -MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( +MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping - (RemBertConfig, RemBertForMaskedLM), - (RoFormerConfig, RoFormerForMaskedLM), - (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration), - (GPTNeoConfig, GPTNeoForCausalLM), - (BigBirdConfig, BigBirdForMaskedLM), - (Speech2TextConfig, Speech2TextForConditionalGeneration), - (Wav2Vec2Config, Wav2Vec2ForMaskedLM), - (M2M100Config, M2M100ForConditionalGeneration), - (ConvBertConfig, ConvBertForMaskedLM), - (LEDConfig, LEDForConditionalGeneration), - (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), - (LayoutLMConfig, LayoutLMForMaskedLM), - (T5Config, T5ForConditionalGeneration), - (DistilBertConfig, DistilBertForMaskedLM), - (AlbertConfig, AlbertForMaskedLM), - (CamembertConfig, CamembertForMaskedLM), - (XLMRobertaConfig, XLMRobertaForMaskedLM), - (MarianConfig, MarianMTModel), - (FSMTConfig, FSMTForConditionalGeneration), - (BartConfig, BartForConditionalGeneration), - (LongformerConfig, LongformerForMaskedLM), - (RobertaConfig, RobertaForMaskedLM), - (SqueezeBertConfig, SqueezeBertForMaskedLM), - (BertConfig, BertForMaskedLM), - (OpenAIGPTConfig, OpenAIGPTLMHeadModel), - (GPT2Config, GPT2LMHeadModel), - (MegatronBertConfig, MegatronBertForMaskedLM), - (MobileBertConfig, MobileBertForMaskedLM), - (TransfoXLConfig, TransfoXLLMHeadModel), - (XLNetConfig, XLNetLMHeadModel), - (FlaubertConfig, FlaubertWithLMHeadModel), - (XLMConfig, XLMWithLMHeadModel), - (CTRLConfig, CTRLLMHeadModel), - (ElectraConfig, ElectraForMaskedLM), - (EncoderDecoderConfig, EncoderDecoderModel), - (ReformerConfig, ReformerModelWithLMHead), - (FunnelConfig, FunnelForMaskedLM), - (MPNetConfig, MPNetForMaskedLM), - (TapasConfig, TapasForMaskedLM), - (DebertaConfig, DebertaForMaskedLM), - (DebertaV2Config, DebertaV2ForMaskedLM), - (IBertConfig, IBertForMaskedLM), - 
(MegatronBertConfig, MegatronBertForCausalLM), + ("rembert", "RemBertForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("big_bird", "BigBirdForMaskedLM"), + ("speech_to_text", "Speech2TextForConditionalGeneration"), + ("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("convbert", "ConvBertForMaskedLM"), + ("led", "LEDForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("t5", "T5ForConditionalGeneration"), + ("distilbert", "DistilBertForMaskedLM"), + ("albert", "AlbertForMaskedLM"), + ("camembert", "CamembertForMaskedLM"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("marian", "MarianMTModel"), + ("fsmt", "FSMTForConditionalGeneration"), + ("bart", "BartForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("bert", "BertForMaskedLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("xlnet", "XLNetLMHeadModel"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("ctrl", "CTRLLMHeadModel"), + ("electra", "ElectraForMaskedLM"), + ("encoder-decoder", "EncoderDecoderModel"), + ("reformer", "ReformerModelWithLMHead"), + ("funnel", "FunnelForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("ibert", "IBertForMaskedLM"), ] ) -MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( +MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - (RemBertConfig, RemBertForCausalLM), - (RoFormerConfig, RoFormerForCausalLM), - (BigBirdPegasusConfig, 
BigBirdPegasusForCausalLM), - (GPTNeoConfig, GPTNeoForCausalLM), - (BigBirdConfig, BigBirdForCausalLM), - (CamembertConfig, CamembertForCausalLM), - (XLMRobertaConfig, XLMRobertaForCausalLM), - (RobertaConfig, RobertaForCausalLM), - (BertConfig, BertLMHeadModel), - (OpenAIGPTConfig, OpenAIGPTLMHeadModel), - (GPT2Config, GPT2LMHeadModel), - (TransfoXLConfig, TransfoXLLMHeadModel), - (XLNetConfig, XLNetLMHeadModel), - ( - XLMConfig, - XLMWithLMHeadModel, - ), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now - (CTRLConfig, CTRLLMHeadModel), - (ReformerConfig, ReformerModelWithLMHead), - (BertGenerationConfig, BertGenerationDecoder), - (XLMProphetNetConfig, XLMProphetNetForCausalLM), - (ProphetNetConfig, ProphetNetForCausalLM), - (BartConfig, BartForCausalLM), - (MBartConfig, MBartForCausalLM), - (PegasusConfig, PegasusForCausalLM), - (MarianConfig, MarianForCausalLM), - (BlenderbotConfig, BlenderbotForCausalLM), - (BlenderbotSmallConfig, BlenderbotSmallForCausalLM), - (MegatronBertConfig, MegatronBertForCausalLM), + ("rembert", "RemBertForCausalLM"), + ("roformer", "RoFormerForCausalLM"), + ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("big_bird", "BigBirdForCausalLM"), + ("camembert", "CamembertForCausalLM"), + ("xlm-roberta", "XLMRobertaForCausalLM"), + ("roberta", "RobertaForCausalLM"), + ("bert", "BertLMHeadModel"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("gpt2", "GPT2LMHeadModel"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("xlnet", "XLNetLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("ctrl", "CTRLLMHeadModel"), + ("reformer", "ReformerModelWithLMHead"), + ("bert-generation", "BertGenerationDecoder"), + ("xlm-prophetnet", "XLMProphetNetForCausalLM"), + ("prophetnet", "ProphetNetForCausalLM"), + ("bart", "BartForCausalLM"), + ("mbart", "MBartForCausalLM"), + ("pegasus", "PegasusForCausalLM"), + ("marian", "MarianForCausalLM"), + ("blenderbot", "BlenderbotForCausalLM"), + 
("blenderbot-small", "BlenderbotSmallForCausalLM"), + ("megatron-bert", "MegatronBertForCausalLM"), ] ) -MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image Classification mapping - (ViTConfig, ViTForImageClassification), - (DeiTConfig, (DeiTForImageClassification, DeiTForImageClassificationWithTeacher)), - (BeitConfig, BeitForImageClassification), + ("vit", "ViTForImageClassification"), + ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), + ("beit", "BeitForImageClassification"), ] ) -MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( +MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - (RemBertConfig, RemBertForMaskedLM), - (RoFormerConfig, RoFormerForMaskedLM), - (BigBirdConfig, BigBirdForMaskedLM), - (Wav2Vec2Config, Wav2Vec2ForMaskedLM), - (ConvBertConfig, ConvBertForMaskedLM), - (LayoutLMConfig, LayoutLMForMaskedLM), - (DistilBertConfig, DistilBertForMaskedLM), - (AlbertConfig, AlbertForMaskedLM), - (BartConfig, BartForConditionalGeneration), - (MBartConfig, MBartForConditionalGeneration), - (CamembertConfig, CamembertForMaskedLM), - (XLMRobertaConfig, XLMRobertaForMaskedLM), - (LongformerConfig, LongformerForMaskedLM), - (RobertaConfig, RobertaForMaskedLM), - (SqueezeBertConfig, SqueezeBertForMaskedLM), - (BertConfig, BertForMaskedLM), - (MegatronBertConfig, MegatronBertForMaskedLM), - (MobileBertConfig, MobileBertForMaskedLM), - (FlaubertConfig, FlaubertWithLMHeadModel), - (XLMConfig, XLMWithLMHeadModel), - (ElectraConfig, ElectraForMaskedLM), - (ReformerConfig, ReformerForMaskedLM), - (FunnelConfig, FunnelForMaskedLM), - (MPNetConfig, MPNetForMaskedLM), - (TapasConfig, TapasForMaskedLM), - (DebertaConfig, DebertaForMaskedLM), - (DebertaV2Config, DebertaV2ForMaskedLM), - (IBertConfig, IBertForMaskedLM), + ("rembert", "RemBertForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("big_bird", "BigBirdForMaskedLM"), + 
("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("convbert", "ConvBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("distilbert", "DistilBertForMaskedLM"), + ("albert", "AlbertForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("mbart", "MBartForConditionalGeneration"), + ("camembert", "CamembertForMaskedLM"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("longformer", "LongformerForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("bert", "BertForMaskedLM"), + ("megatron-bert", "MegatronBertForMaskedLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("xlm", "XLMWithLMHeadModel"), + ("electra", "ElectraForMaskedLM"), + ("reformer", "ReformerForMaskedLM"), + ("funnel", "FunnelForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), + ("tapas", "TapasForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("ibert", "IBertForMaskedLM"), ] ) -MODEL_FOR_OBJECT_DETECTION_MAPPING = OrderedDict( +MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( [ # Model for Object Detection mapping - (DetrConfig, DetrForObjectDetection), + ("detr", "DetrForObjectDetection"), ] ) -MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration), - (M2M100Config, M2M100ForConditionalGeneration), - (LEDConfig, LEDForConditionalGeneration), - (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), - (MT5Config, MT5ForConditionalGeneration), - (T5Config, T5ForConditionalGeneration), - (PegasusConfig, PegasusForConditionalGeneration), - (MarianConfig, MarianMTModel), - (MBartConfig, MBartForConditionalGeneration), - (BlenderbotConfig, BlenderbotForConditionalGeneration), - (BartConfig, BartForConditionalGeneration), - (FSMTConfig, FSMTForConditionalGeneration), - (EncoderDecoderConfig, 
EncoderDecoderModel), - (XLMProphetNetConfig, XLMProphetNetForConditionalGeneration), - (ProphetNetConfig, ProphetNetForConditionalGeneration), + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("led", "LEDForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("mt5", "MT5ForConditionalGeneration"), + ("t5", "T5ForConditionalGeneration"), + ("pegasus", "PegasusForConditionalGeneration"), + ("marian", "MarianMTModel"), + ("mbart", "MBartForConditionalGeneration"), + ("blenderbot", "BlenderbotForConditionalGeneration"), + ("bart", "BartForConditionalGeneration"), + ("fsmt", "FSMTForConditionalGeneration"), + ("encoder-decoder", "EncoderDecoderModel"), + ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), + ("prophetnet", "ProphetNetForConditionalGeneration"), ] ) -MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( +MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - (RemBertConfig, RemBertForSequenceClassification), - (CanineConfig, CanineForSequenceClassification), - (RoFormerConfig, RoFormerForSequenceClassification), - (BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification), - (BigBirdConfig, BigBirdForSequenceClassification), - (ConvBertConfig, ConvBertForSequenceClassification), - (LEDConfig, LEDForSequenceClassification), - (DistilBertConfig, DistilBertForSequenceClassification), - (AlbertConfig, AlbertForSequenceClassification), - (CamembertConfig, CamembertForSequenceClassification), - (XLMRobertaConfig, XLMRobertaForSequenceClassification), - (MBartConfig, MBartForSequenceClassification), - (BartConfig, BartForSequenceClassification), - (LongformerConfig, LongformerForSequenceClassification), - (RobertaConfig, RobertaForSequenceClassification), - (SqueezeBertConfig, SqueezeBertForSequenceClassification), - (LayoutLMConfig, LayoutLMForSequenceClassification), - 
(BertConfig, BertForSequenceClassification), - (XLNetConfig, XLNetForSequenceClassification), - (MegatronBertConfig, MegatronBertForSequenceClassification), - (MobileBertConfig, MobileBertForSequenceClassification), - (FlaubertConfig, FlaubertForSequenceClassification), - (XLMConfig, XLMForSequenceClassification), - (ElectraConfig, ElectraForSequenceClassification), - (FunnelConfig, FunnelForSequenceClassification), - (DebertaConfig, DebertaForSequenceClassification), - (DebertaV2Config, DebertaV2ForSequenceClassification), - (GPT2Config, GPT2ForSequenceClassification), - (GPTNeoConfig, GPTNeoForSequenceClassification), - (OpenAIGPTConfig, OpenAIGPTForSequenceClassification), - (ReformerConfig, ReformerForSequenceClassification), - (CTRLConfig, CTRLForSequenceClassification), - (TransfoXLConfig, TransfoXLForSequenceClassification), - (MPNetConfig, MPNetForSequenceClassification), - (TapasConfig, TapasForSequenceClassification), - (IBertConfig, IBertForSequenceClassification), + ("rembert", "RemBertForSequenceClassification"), + ("canine", "CanineForSequenceClassification"), + ("roformer", "RoFormerForSequenceClassification"), + ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), + ("big_bird", "BigBirdForSequenceClassification"), + ("convbert", "ConvBertForSequenceClassification"), + ("led", "LEDForSequenceClassification"), + ("distilbert", "DistilBertForSequenceClassification"), + ("albert", "AlbertForSequenceClassification"), + ("camembert", "CamembertForSequenceClassification"), + ("xlm-roberta", "XLMRobertaForSequenceClassification"), + ("mbart", "MBartForSequenceClassification"), + ("bart", "BartForSequenceClassification"), + ("longformer", "LongformerForSequenceClassification"), + ("roberta", "RobertaForSequenceClassification"), + ("squeezebert", "SqueezeBertForSequenceClassification"), + ("layoutlm", "LayoutLMForSequenceClassification"), + ("bert", "BertForSequenceClassification"), + ("xlnet", "XLNetForSequenceClassification"), + 
("megatron-bert", "MegatronBertForSequenceClassification"), + ("mobilebert", "MobileBertForSequenceClassification"), + ("flaubert", "FlaubertForSequenceClassification"), + ("xlm", "XLMForSequenceClassification"), + ("electra", "ElectraForSequenceClassification"), + ("funnel", "FunnelForSequenceClassification"), + ("deberta", "DebertaForSequenceClassification"), + ("deberta-v2", "DebertaV2ForSequenceClassification"), + ("gpt2", "GPT2ForSequenceClassification"), + ("gpt_neo", "GPTNeoForSequenceClassification"), + ("openai-gpt", "OpenAIGPTForSequenceClassification"), + ("reformer", "ReformerForSequenceClassification"), + ("ctrl", "CTRLForSequenceClassification"), + ("transfo-xl", "TransfoXLForSequenceClassification"), + ("mpnet", "MPNetForSequenceClassification"), + ("tapas", "TapasForSequenceClassification"), + ("ibert", "IBertForSequenceClassification"), ] ) -MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( +MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - (RemBertConfig, RemBertForQuestionAnswering), - (CanineConfig, CanineForQuestionAnswering), - (RoFormerConfig, RoFormerForQuestionAnswering), - (BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering), - (BigBirdConfig, BigBirdForQuestionAnswering), - (ConvBertConfig, ConvBertForQuestionAnswering), - (LEDConfig, LEDForQuestionAnswering), - (DistilBertConfig, DistilBertForQuestionAnswering), - (AlbertConfig, AlbertForQuestionAnswering), - (CamembertConfig, CamembertForQuestionAnswering), - (BartConfig, BartForQuestionAnswering), - (MBartConfig, MBartForQuestionAnswering), - (LongformerConfig, LongformerForQuestionAnswering), - (XLMRobertaConfig, XLMRobertaForQuestionAnswering), - (RobertaConfig, RobertaForQuestionAnswering), - (SqueezeBertConfig, SqueezeBertForQuestionAnswering), - (BertConfig, BertForQuestionAnswering), - (XLNetConfig, XLNetForQuestionAnsweringSimple), - (FlaubertConfig, FlaubertForQuestionAnsweringSimple), - (MegatronBertConfig, 
MegatronBertForQuestionAnswering), - (MobileBertConfig, MobileBertForQuestionAnswering), - (XLMConfig, XLMForQuestionAnsweringSimple), - (ElectraConfig, ElectraForQuestionAnswering), - (ReformerConfig, ReformerForQuestionAnswering), - (FunnelConfig, FunnelForQuestionAnswering), - (LxmertConfig, LxmertForQuestionAnswering), - (MPNetConfig, MPNetForQuestionAnswering), - (DebertaConfig, DebertaForQuestionAnswering), - (DebertaV2Config, DebertaV2ForQuestionAnswering), - (IBertConfig, IBertForQuestionAnswering), + ("rembert", "RemBertForQuestionAnswering"), + ("canine", "CanineForQuestionAnswering"), + ("roformer", "RoFormerForQuestionAnswering"), + ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), + ("big_bird", "BigBirdForQuestionAnswering"), + ("convbert", "ConvBertForQuestionAnswering"), + ("led", "LEDForQuestionAnswering"), + ("distilbert", "DistilBertForQuestionAnswering"), + ("albert", "AlbertForQuestionAnswering"), + ("camembert", "CamembertForQuestionAnswering"), + ("bart", "BartForQuestionAnswering"), + ("mbart", "MBartForQuestionAnswering"), + ("longformer", "LongformerForQuestionAnswering"), + ("xlm-roberta", "XLMRobertaForQuestionAnswering"), + ("roberta", "RobertaForQuestionAnswering"), + ("squeezebert", "SqueezeBertForQuestionAnswering"), + ("bert", "BertForQuestionAnswering"), + ("xlnet", "XLNetForQuestionAnsweringSimple"), + ("flaubert", "FlaubertForQuestionAnsweringSimple"), + ("megatron-bert", "MegatronBertForQuestionAnswering"), + ("mobilebert", "MobileBertForQuestionAnswering"), + ("xlm", "XLMForQuestionAnsweringSimple"), + ("electra", "ElectraForQuestionAnswering"), + ("reformer", "ReformerForQuestionAnswering"), + ("funnel", "FunnelForQuestionAnswering"), + ("lxmert", "LxmertForQuestionAnswering"), + ("mpnet", "MPNetForQuestionAnswering"), + ("deberta", "DebertaForQuestionAnswering"), + ("deberta-v2", "DebertaV2ForQuestionAnswering"), + ("ibert", "IBertForQuestionAnswering"), ] ) -MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict( 
+MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Table Question Answering mapping - (TapasConfig, TapasForQuestionAnswering), + ("tapas", "TapasForQuestionAnswering"), ] ) -MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - (RemBertConfig, RemBertForTokenClassification), - (CanineConfig, CanineForTokenClassification), - (RoFormerConfig, RoFormerForTokenClassification), - (BigBirdConfig, BigBirdForTokenClassification), - (ConvBertConfig, ConvBertForTokenClassification), - (LayoutLMConfig, LayoutLMForTokenClassification), - (DistilBertConfig, DistilBertForTokenClassification), - (CamembertConfig, CamembertForTokenClassification), - (FlaubertConfig, FlaubertForTokenClassification), - (XLMConfig, XLMForTokenClassification), - (XLMRobertaConfig, XLMRobertaForTokenClassification), - (LongformerConfig, LongformerForTokenClassification), - (RobertaConfig, RobertaForTokenClassification), - (SqueezeBertConfig, SqueezeBertForTokenClassification), - (BertConfig, BertForTokenClassification), - (MegatronBertConfig, MegatronBertForTokenClassification), - (MobileBertConfig, MobileBertForTokenClassification), - (XLNetConfig, XLNetForTokenClassification), - (AlbertConfig, AlbertForTokenClassification), - (ElectraConfig, ElectraForTokenClassification), - (FlaubertConfig, FlaubertForTokenClassification), - (FunnelConfig, FunnelForTokenClassification), - (MPNetConfig, MPNetForTokenClassification), - (DebertaConfig, DebertaForTokenClassification), - (DebertaV2Config, DebertaV2ForTokenClassification), - (IBertConfig, IBertForTokenClassification), + ("rembert", "RemBertForTokenClassification"), + ("canine", "CanineForTokenClassification"), + ("roformer", "RoFormerForTokenClassification"), + ("big_bird", "BigBirdForTokenClassification"), + ("convbert", "ConvBertForTokenClassification"), + ("layoutlm", "LayoutLMForTokenClassification"), + ("distilbert", 
"DistilBertForTokenClassification"), + ("camembert", "CamembertForTokenClassification"), + ("flaubert", "FlaubertForTokenClassification"), + ("xlm", "XLMForTokenClassification"), + ("xlm-roberta", "XLMRobertaForTokenClassification"), + ("longformer", "LongformerForTokenClassification"), + ("roberta", "RobertaForTokenClassification"), + ("squeezebert", "SqueezeBertForTokenClassification"), + ("bert", "BertForTokenClassification"), + ("megatron-bert", "MegatronBertForTokenClassification"), + ("mobilebert", "MobileBertForTokenClassification"), + ("xlnet", "XLNetForTokenClassification"), + ("albert", "AlbertForTokenClassification"), + ("electra", "ElectraForTokenClassification"), + ("funnel", "FunnelForTokenClassification"), + ("mpnet", "MPNetForTokenClassification"), + ("deberta", "DebertaForTokenClassification"), + ("deberta-v2", "DebertaV2ForTokenClassification"), + ("ibert", "IBertForTokenClassification"), ] ) -MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( +MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - (RemBertConfig, RemBertForMultipleChoice), - (CanineConfig, CanineForMultipleChoice), - (RoFormerConfig, RoFormerForMultipleChoice), - (BigBirdConfig, BigBirdForMultipleChoice), - (ConvBertConfig, ConvBertForMultipleChoice), - (CamembertConfig, CamembertForMultipleChoice), - (ElectraConfig, ElectraForMultipleChoice), - (XLMRobertaConfig, XLMRobertaForMultipleChoice), - (LongformerConfig, LongformerForMultipleChoice), - (RobertaConfig, RobertaForMultipleChoice), - (SqueezeBertConfig, SqueezeBertForMultipleChoice), - (BertConfig, BertForMultipleChoice), - (DistilBertConfig, DistilBertForMultipleChoice), - (MegatronBertConfig, MegatronBertForMultipleChoice), - (MobileBertConfig, MobileBertForMultipleChoice), - (XLNetConfig, XLNetForMultipleChoice), - (AlbertConfig, AlbertForMultipleChoice), - (XLMConfig, XLMForMultipleChoice), - (FlaubertConfig, FlaubertForMultipleChoice), - (FunnelConfig, FunnelForMultipleChoice), - 
(MPNetConfig, MPNetForMultipleChoice), - (IBertConfig, IBertForMultipleChoice), + ("rembert", "RemBertForMultipleChoice"), + ("canine", "CanineForMultipleChoice"), + ("roformer", "RoFormerForMultipleChoice"), + ("big_bird", "BigBirdForMultipleChoice"), + ("convbert", "ConvBertForMultipleChoice"), + ("camembert", "CamembertForMultipleChoice"), + ("electra", "ElectraForMultipleChoice"), + ("xlm-roberta", "XLMRobertaForMultipleChoice"), + ("longformer", "LongformerForMultipleChoice"), + ("roberta", "RobertaForMultipleChoice"), + ("squeezebert", "SqueezeBertForMultipleChoice"), + ("bert", "BertForMultipleChoice"), + ("distilbert", "DistilBertForMultipleChoice"), + ("megatron-bert", "MegatronBertForMultipleChoice"), + ("mobilebert", "MobileBertForMultipleChoice"), + ("xlnet", "XLNetForMultipleChoice"), + ("albert", "AlbertForMultipleChoice"), + ("xlm", "XLMForMultipleChoice"), + ("flaubert", "FlaubertForMultipleChoice"), + ("funnel", "FunnelForMultipleChoice"), + ("mpnet", "MPNetForMultipleChoice"), + ("ibert", "IBertForMultipleChoice"), ] ) -MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( [ - (BertConfig, BertForNextSentencePrediction), - (MegatronBertConfig, MegatronBertForNextSentencePrediction), - (MobileBertConfig, MobileBertForNextSentencePrediction), + ("bert", "BertForNextSentencePrediction"), + ("megatron-bert", "MegatronBertForNextSentencePrediction"), + ("mobilebert", "MobileBertForNextSentencePrediction"), ] ) +MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) +MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) +MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) +MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, 
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) +MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) +MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) +MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES +) +MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES +) +MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES) +MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES +) + class AutoModel(_BaseAutoModelClass): _model_mapping = MODEL_MAPPING diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 47029243ba9..5e731c58cca 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -18,207 +18,163 @@ from collections import OrderedDict from ...utils import logging -from ..bart.modeling_flax_bart import ( - FlaxBartForConditionalGeneration, - FlaxBartForQuestionAnswering, - FlaxBartForSequenceClassification, - FlaxBartModel, -) -from ..bert.modeling_flax_bert import ( - FlaxBertForMaskedLM, - FlaxBertForMultipleChoice, - FlaxBertForNextSentencePrediction, - FlaxBertForPreTraining, - FlaxBertForQuestionAnswering, - 
FlaxBertForSequenceClassification, - FlaxBertForTokenClassification, - FlaxBertModel, -) -from ..big_bird.modeling_flax_big_bird import ( - FlaxBigBirdForMaskedLM, - FlaxBigBirdForMultipleChoice, - FlaxBigBirdForPreTraining, - FlaxBigBirdForQuestionAnswering, - FlaxBigBirdForSequenceClassification, - FlaxBigBirdForTokenClassification, - FlaxBigBirdModel, -) -from ..clip.modeling_flax_clip import FlaxCLIPModel -from ..electra.modeling_flax_electra import ( - FlaxElectraForMaskedLM, - FlaxElectraForMultipleChoice, - FlaxElectraForPreTraining, - FlaxElectraForQuestionAnswering, - FlaxElectraForSequenceClassification, - FlaxElectraForTokenClassification, - FlaxElectraModel, -) -from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model -from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel -from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel -from ..mbart.modeling_flax_mbart import ( - FlaxMBartForConditionalGeneration, - FlaxMBartForQuestionAnswering, - FlaxMBartForSequenceClassification, - FlaxMBartModel, -) -from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model -from ..roberta.modeling_flax_roberta import ( - FlaxRobertaForMaskedLM, - FlaxRobertaForMultipleChoice, - FlaxRobertaForQuestionAnswering, - FlaxRobertaForSequenceClassification, - FlaxRobertaForTokenClassification, - FlaxRobertaModel, -) -from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model -from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel -from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model -from .auto_factory import _BaseAutoModelClass, auto_class_update -from .configuration_auto import ( - BartConfig, - BertConfig, - BigBirdConfig, - CLIPConfig, - ElectraConfig, - GPT2Config, - GPTNeoConfig, - MarianConfig, - MBartConfig, - MT5Config, - RobertaConfig, - T5Config, - ViTConfig, - Wav2Vec2Config, -) +from .auto_factory 
import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .configuration_auto import CONFIG_MAPPING_NAMES logger = logging.get_logger(__name__) -FLAX_MODEL_MAPPING = OrderedDict( +FLAX_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping - (RobertaConfig, FlaxRobertaModel), - (BertConfig, FlaxBertModel), - (BigBirdConfig, FlaxBigBirdModel), - (BartConfig, FlaxBartModel), - (GPT2Config, FlaxGPT2Model), - (GPTNeoConfig, FlaxGPTNeoModel), - (ElectraConfig, FlaxElectraModel), - (CLIPConfig, FlaxCLIPModel), - (ViTConfig, FlaxViTModel), - (MBartConfig, FlaxMBartModel), - (T5Config, FlaxT5Model), - (MT5Config, FlaxMT5Model), - (Wav2Vec2Config, FlaxWav2Vec2Model), - (MarianConfig, FlaxMarianModel), + ("roberta", "FlaxRobertaModel"), + ("bert", "FlaxBertModel"), + ("big_bird", "FlaxBigBirdModel"), + ("bart", "FlaxBartModel"), + ("gpt2", "FlaxGPT2Model"), + ("gpt_neo", "FlaxGPTNeoModel"), + ("electra", "FlaxElectraModel"), + ("clip", "FlaxCLIPModel"), + ("vit", "FlaxViTModel"), + ("mbart", "FlaxMBartModel"), + ("t5", "FlaxT5Model"), + ("mt5", "FlaxMT5Model"), + ("wav2vec2", "FlaxWav2Vec2Model"), + ("marian", "FlaxMarianModel"), ] ) -FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( +FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping - (RobertaConfig, FlaxRobertaForMaskedLM), - (BertConfig, FlaxBertForPreTraining), - (BigBirdConfig, FlaxBigBirdForPreTraining), - (BartConfig, FlaxBartForConditionalGeneration), - (ElectraConfig, FlaxElectraForPreTraining), - (MBartConfig, FlaxMBartForConditionalGeneration), - (T5Config, FlaxT5ForConditionalGeneration), - (MT5Config, FlaxMT5ForConditionalGeneration), - (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining), + ("roberta", "FlaxRobertaForMaskedLM"), + ("bert", "FlaxBertForPreTraining"), + ("big_bird", "FlaxBigBirdForPreTraining"), + ("bart", "FlaxBartForConditionalGeneration"), + ("electra", "FlaxElectraForPreTraining"), + ("mbart", "FlaxMBartForConditionalGeneration"), + ("t5", 
"FlaxT5ForConditionalGeneration"), + ("mt5", "FlaxMT5ForConditionalGeneration"), + ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), ] ) -FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( +FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - (RobertaConfig, FlaxRobertaForMaskedLM), - (BertConfig, FlaxBertForMaskedLM), - (BigBirdConfig, FlaxBigBirdForMaskedLM), - (BartConfig, FlaxBartForConditionalGeneration), - (ElectraConfig, FlaxElectraForMaskedLM), - (MBartConfig, FlaxMBartForConditionalGeneration), + ("roberta", "FlaxRobertaForMaskedLM"), + ("bert", "FlaxBertForMaskedLM"), + ("big_bird", "FlaxBigBirdForMaskedLM"), + ("bart", "FlaxBartForConditionalGeneration"), + ("electra", "FlaxElectraForMaskedLM"), + ("mbart", "FlaxMBartForConditionalGeneration"), ] ) -FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - (BartConfig, FlaxBartForConditionalGeneration), - (T5Config, FlaxT5ForConditionalGeneration), - (MT5Config, FlaxMT5ForConditionalGeneration), - (MarianConfig, FlaxMarianMTModel), + ("bart", "FlaxBartForConditionalGeneration"), + ("t5", "FlaxT5ForConditionalGeneration"), + ("mt5", "FlaxMT5ForConditionalGeneration"), + ("marian", "FlaxMarianMTModel"), ] ) -FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image-classsification - (ViTConfig, FlaxViTForImageClassification), + ("vit", "FlaxViTForImageClassification"), ] ) -FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( +FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - (GPT2Config, FlaxGPT2LMHeadModel), - (GPTNeoConfig, FlaxGPTNeoForCausalLM), + ("gpt2", "FlaxGPT2LMHeadModel"), + ("gpt_neo", "FlaxGPTNeoForCausalLM"), ] ) -FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( +FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = 
OrderedDict( [ # Model for Sequence Classification mapping - (RobertaConfig, FlaxRobertaForSequenceClassification), - (BertConfig, FlaxBertForSequenceClassification), - (BigBirdConfig, FlaxBigBirdForSequenceClassification), - (BartConfig, FlaxBartForSequenceClassification), - (ElectraConfig, FlaxElectraForSequenceClassification), - (MBartConfig, FlaxMBartForSequenceClassification), + ("roberta", "FlaxRobertaForSequenceClassification"), + ("bert", "FlaxBertForSequenceClassification"), + ("big_bird", "FlaxBigBirdForSequenceClassification"), + ("bart", "FlaxBartForSequenceClassification"), + ("electra", "FlaxElectraForSequenceClassification"), + ("mbart", "FlaxMBartForSequenceClassification"), ] ) -FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( +FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - (RobertaConfig, FlaxRobertaForQuestionAnswering), - (BertConfig, FlaxBertForQuestionAnswering), - (BigBirdConfig, FlaxBigBirdForQuestionAnswering), - (BartConfig, FlaxBartForQuestionAnswering), - (ElectraConfig, FlaxElectraForQuestionAnswering), - (MBartConfig, FlaxMBartForQuestionAnswering), + ("roberta", "FlaxRobertaForQuestionAnswering"), + ("bert", "FlaxBertForQuestionAnswering"), + ("big_bird", "FlaxBigBirdForQuestionAnswering"), + ("bart", "FlaxBartForQuestionAnswering"), + ("electra", "FlaxElectraForQuestionAnswering"), + ("mbart", "FlaxMBartForQuestionAnswering"), ] ) -FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( +FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - (RobertaConfig, FlaxRobertaForTokenClassification), - (BertConfig, FlaxBertForTokenClassification), - (BigBirdConfig, FlaxBigBirdForTokenClassification), - (ElectraConfig, FlaxElectraForTokenClassification), + ("roberta", "FlaxRobertaForTokenClassification"), + ("bert", "FlaxBertForTokenClassification"), + ("big_bird", "FlaxBigBirdForTokenClassification"), + ("electra", 
"FlaxElectraForTokenClassification"), ] ) -FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( +FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - (RobertaConfig, FlaxRobertaForMultipleChoice), - (BertConfig, FlaxBertForMultipleChoice), - (BigBirdConfig, FlaxBigBirdForMultipleChoice), - (ElectraConfig, FlaxElectraForMultipleChoice), + ("roberta", "FlaxRobertaForMultipleChoice"), + ("bert", "FlaxBertForMultipleChoice"), + ("big_bird", "FlaxBigBirdForMultipleChoice"), + ("electra", "FlaxElectraForMultipleChoice"), ] ) -FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( +FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( [ - (BertConfig, FlaxBertForNextSentencePrediction), + ("bert", "FlaxBertForNextSentencePrediction"), ] ) +FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES) +FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES) +FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES) +FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) +FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) +FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +) +FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES +) +FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES +) 
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES +) +FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES +) + + class FlaxAutoModel(_BaseAutoModelClass): _model_mapping = FLAX_MODEL_MAPPING diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index b1968b58cfd..8e9745763fb 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -19,492 +19,298 @@ import warnings from collections import OrderedDict from ...utils import logging - -# Add modeling imports here -from ..albert.modeling_tf_albert import ( - TFAlbertForMaskedLM, - TFAlbertForMultipleChoice, - TFAlbertForPreTraining, - TFAlbertForQuestionAnswering, - TFAlbertForSequenceClassification, - TFAlbertForTokenClassification, - TFAlbertModel, -) -from ..bart.modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel -from ..bert.modeling_tf_bert import ( - TFBertForMaskedLM, - TFBertForMultipleChoice, - TFBertForNextSentencePrediction, - TFBertForPreTraining, - TFBertForQuestionAnswering, - TFBertForSequenceClassification, - TFBertForTokenClassification, - TFBertLMHeadModel, - TFBertModel, -) -from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel -from ..blenderbot_small.modeling_tf_blenderbot_small import ( - TFBlenderbotSmallForConditionalGeneration, - TFBlenderbotSmallModel, -) -from ..camembert.modeling_tf_camembert import ( - TFCamembertForMaskedLM, - TFCamembertForMultipleChoice, - TFCamembertForQuestionAnswering, - TFCamembertForSequenceClassification, - TFCamembertForTokenClassification, - TFCamembertModel, -) -from ..convbert.modeling_tf_convbert import ( - TFConvBertForMaskedLM, - TFConvBertForMultipleChoice, - TFConvBertForQuestionAnswering, - 
TFConvBertForSequenceClassification, - TFConvBertForTokenClassification, - TFConvBertModel, -) -from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel -from ..distilbert.modeling_tf_distilbert import ( - TFDistilBertForMaskedLM, - TFDistilBertForMultipleChoice, - TFDistilBertForQuestionAnswering, - TFDistilBertForSequenceClassification, - TFDistilBertForTokenClassification, - TFDistilBertModel, -) -from ..dpr.modeling_tf_dpr import TFDPRQuestionEncoder -from ..electra.modeling_tf_electra import ( - TFElectraForMaskedLM, - TFElectraForMultipleChoice, - TFElectraForPreTraining, - TFElectraForQuestionAnswering, - TFElectraForSequenceClassification, - TFElectraForTokenClassification, - TFElectraModel, -) -from ..flaubert.modeling_tf_flaubert import ( - TFFlaubertForMultipleChoice, - TFFlaubertForQuestionAnsweringSimple, - TFFlaubertForSequenceClassification, - TFFlaubertForTokenClassification, - TFFlaubertModel, - TFFlaubertWithLMHeadModel, -) -from ..funnel.modeling_tf_funnel import ( - TFFunnelBaseModel, - TFFunnelForMaskedLM, - TFFunnelForMultipleChoice, - TFFunnelForPreTraining, - TFFunnelForQuestionAnswering, - TFFunnelForSequenceClassification, - TFFunnelForTokenClassification, - TFFunnelModel, -) -from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model -from ..hubert.modeling_tf_hubert import TFHubertModel -from ..layoutlm.modeling_tf_layoutlm import ( - TFLayoutLMForMaskedLM, - TFLayoutLMForSequenceClassification, - TFLayoutLMForTokenClassification, - TFLayoutLMModel, -) -from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel -from ..longformer.modeling_tf_longformer import ( - TFLongformerForMaskedLM, - TFLongformerForMultipleChoice, - TFLongformerForQuestionAnswering, - TFLongformerForSequenceClassification, - TFLongformerForTokenClassification, - TFLongformerModel, -) -from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel 
-from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel -from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel -from ..mobilebert.modeling_tf_mobilebert import ( - TFMobileBertForMaskedLM, - TFMobileBertForMultipleChoice, - TFMobileBertForNextSentencePrediction, - TFMobileBertForPreTraining, - TFMobileBertForQuestionAnswering, - TFMobileBertForSequenceClassification, - TFMobileBertForTokenClassification, - TFMobileBertModel, -) -from ..mpnet.modeling_tf_mpnet import ( - TFMPNetForMaskedLM, - TFMPNetForMultipleChoice, - TFMPNetForQuestionAnswering, - TFMPNetForSequenceClassification, - TFMPNetForTokenClassification, - TFMPNetModel, -) -from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model -from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel -from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel -from ..rembert.modeling_tf_rembert import ( - TFRemBertForCausalLM, - TFRemBertForMaskedLM, - TFRemBertForMultipleChoice, - TFRemBertForQuestionAnswering, - TFRemBertForSequenceClassification, - TFRemBertForTokenClassification, - TFRemBertModel, -) -from ..roberta.modeling_tf_roberta import ( - TFRobertaForMaskedLM, - TFRobertaForMultipleChoice, - TFRobertaForQuestionAnswering, - TFRobertaForSequenceClassification, - TFRobertaForTokenClassification, - TFRobertaModel, -) -from ..roformer.modeling_tf_roformer import ( - TFRoFormerForCausalLM, - TFRoFormerForMaskedLM, - TFRoFormerForMultipleChoice, - TFRoFormerForQuestionAnswering, - TFRoFormerForSequenceClassification, - TFRoFormerForTokenClassification, - TFRoFormerModel, -) -from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model -from ..transfo_xl.modeling_tf_transfo_xl import ( - TFTransfoXLForSequenceClassification, - TFTransfoXLLMHeadModel, - TFTransfoXLModel, -) -from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model -from 
..xlm.modeling_tf_xlm import ( - TFXLMForMultipleChoice, - TFXLMForQuestionAnsweringSimple, - TFXLMForSequenceClassification, - TFXLMForTokenClassification, - TFXLMModel, - TFXLMWithLMHeadModel, -) -from ..xlm_roberta.modeling_tf_xlm_roberta import ( - TFXLMRobertaForMaskedLM, - TFXLMRobertaForMultipleChoice, - TFXLMRobertaForQuestionAnswering, - TFXLMRobertaForSequenceClassification, - TFXLMRobertaForTokenClassification, - TFXLMRobertaModel, -) -from ..xlnet.modeling_tf_xlnet import ( - TFXLNetForMultipleChoice, - TFXLNetForQuestionAnsweringSimple, - TFXLNetForSequenceClassification, - TFXLNetForTokenClassification, - TFXLNetLMHeadModel, - TFXLNetModel, -) -from .auto_factory import _BaseAutoModelClass, auto_class_update -from .configuration_auto import ( - AlbertConfig, - BartConfig, - BertConfig, - BlenderbotConfig, - BlenderbotSmallConfig, - CamembertConfig, - ConvBertConfig, - CTRLConfig, - DistilBertConfig, - DPRConfig, - ElectraConfig, - FlaubertConfig, - FunnelConfig, - GPT2Config, - HubertConfig, - LayoutLMConfig, - LEDConfig, - LongformerConfig, - LxmertConfig, - MarianConfig, - MBartConfig, - MobileBertConfig, - MPNetConfig, - MT5Config, - OpenAIGPTConfig, - PegasusConfig, - RemBertConfig, - RobertaConfig, - RoFormerConfig, - T5Config, - TransfoXLConfig, - Wav2Vec2Config, - XLMConfig, - XLMRobertaConfig, - XLNetConfig, -) +from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .configuration_auto import CONFIG_MAPPING_NAMES logger = logging.get_logger(__name__) -TF_MODEL_MAPPING = OrderedDict( +TF_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping - (RemBertConfig, TFRemBertModel), - (RoFormerConfig, TFRoFormerModel), - (ConvBertConfig, TFConvBertModel), - (LEDConfig, TFLEDModel), - (LxmertConfig, TFLxmertModel), - (MT5Config, TFMT5Model), - (T5Config, TFT5Model), - (DistilBertConfig, TFDistilBertModel), - (AlbertConfig, TFAlbertModel), - (BartConfig, TFBartModel), - (CamembertConfig, TFCamembertModel), - 
(XLMRobertaConfig, TFXLMRobertaModel), - (LongformerConfig, TFLongformerModel), - (RobertaConfig, TFRobertaModel), - (LayoutLMConfig, TFLayoutLMModel), - (BertConfig, TFBertModel), - (OpenAIGPTConfig, TFOpenAIGPTModel), - (GPT2Config, TFGPT2Model), - (MobileBertConfig, TFMobileBertModel), - (TransfoXLConfig, TFTransfoXLModel), - (XLNetConfig, TFXLNetModel), - (FlaubertConfig, TFFlaubertModel), - (XLMConfig, TFXLMModel), - (CTRLConfig, TFCTRLModel), - (ElectraConfig, TFElectraModel), - (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)), - (DPRConfig, TFDPRQuestionEncoder), - (MPNetConfig, TFMPNetModel), - (BartConfig, TFBartModel), - (MBartConfig, TFMBartModel), - (MarianConfig, TFMarianModel), - (PegasusConfig, TFPegasusModel), - (BlenderbotConfig, TFBlenderbotModel), - (BlenderbotSmallConfig, TFBlenderbotSmallModel), - (Wav2Vec2Config, TFWav2Vec2Model), - (HubertConfig, TFHubertModel), + ("rembert", "TFRemBertModel"), + ("roformer", "TFRoFormerModel"), + ("convbert", "TFConvBertModel"), + ("led", "TFLEDModel"), + ("lxmert", "TFLxmertModel"), + ("mt5", "TFMT5Model"), + ("t5", "TFT5Model"), + ("distilbert", "TFDistilBertModel"), + ("albert", "TFAlbertModel"), + ("bart", "TFBartModel"), + ("camembert", "TFCamembertModel"), + ("xlm-roberta", "TFXLMRobertaModel"), + ("longformer", "TFLongformerModel"), + ("roberta", "TFRobertaModel"), + ("layoutlm", "TFLayoutLMModel"), + ("bert", "TFBertModel"), + ("openai-gpt", "TFOpenAIGPTModel"), + ("gpt2", "TFGPT2Model"), + ("mobilebert", "TFMobileBertModel"), + ("transfo-xl", "TFTransfoXLModel"), + ("xlnet", "TFXLNetModel"), + ("flaubert", "TFFlaubertModel"), + ("xlm", "TFXLMModel"), + ("ctrl", "TFCTRLModel"), + ("electra", "TFElectraModel"), + ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")), + ("dpr", "TFDPRQuestionEncoder"), + ("mpnet", "TFMPNetModel"), + ("mbart", "TFMBartModel"), + ("marian", "TFMarianModel"), + ("pegasus", "TFPegasusModel"), + ("blenderbot", "TFBlenderbotModel"), + ("blenderbot-small", 
"TFBlenderbotSmallModel"), + ("wav2vec2", "TFWav2Vec2Model"), + ("hubert", "TFHubertModel"), ] ) -TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( +TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping - (LxmertConfig, TFLxmertForPreTraining), - (T5Config, TFT5ForConditionalGeneration), - (DistilBertConfig, TFDistilBertForMaskedLM), - (AlbertConfig, TFAlbertForPreTraining), - (BartConfig, TFBartForConditionalGeneration), - (CamembertConfig, TFCamembertForMaskedLM), - (XLMRobertaConfig, TFXLMRobertaForMaskedLM), - (RobertaConfig, TFRobertaForMaskedLM), - (LayoutLMConfig, TFLayoutLMForMaskedLM), - (BertConfig, TFBertForPreTraining), - (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel), - (GPT2Config, TFGPT2LMHeadModel), - (MobileBertConfig, TFMobileBertForPreTraining), - (TransfoXLConfig, TFTransfoXLLMHeadModel), - (XLNetConfig, TFXLNetLMHeadModel), - (FlaubertConfig, TFFlaubertWithLMHeadModel), - (XLMConfig, TFXLMWithLMHeadModel), - (CTRLConfig, TFCTRLLMHeadModel), - (ElectraConfig, TFElectraForPreTraining), - (FunnelConfig, TFFunnelForPreTraining), - (MPNetConfig, TFMPNetForMaskedLM), + ("lxmert", "TFLxmertForPreTraining"), + ("t5", "TFT5ForConditionalGeneration"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("albert", "TFAlbertForPreTraining"), + ("bart", "TFBartForConditionalGeneration"), + ("camembert", "TFCamembertForMaskedLM"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("bert", "TFBertForPreTraining"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("gpt2", "TFGPT2LMHeadModel"), + ("mobilebert", "TFMobileBertForPreTraining"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), + ("xlnet", "TFXLNetLMHeadModel"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("xlm", "TFXLMWithLMHeadModel"), + ("ctrl", "TFCTRLLMHeadModel"), + ("electra", "TFElectraForPreTraining"), + ("funnel", "TFFunnelForPreTraining"), + ("mpnet", "TFMPNetForMaskedLM"), ] ) 
-TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( +TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping - (RemBertConfig, TFRemBertForMaskedLM), - (RoFormerConfig, TFRoFormerForMaskedLM), - (ConvBertConfig, TFConvBertForMaskedLM), - (LEDConfig, TFLEDForConditionalGeneration), - (T5Config, TFT5ForConditionalGeneration), - (DistilBertConfig, TFDistilBertForMaskedLM), - (AlbertConfig, TFAlbertForMaskedLM), - (MarianConfig, TFMarianMTModel), - (BartConfig, TFBartForConditionalGeneration), - (CamembertConfig, TFCamembertForMaskedLM), - (XLMRobertaConfig, TFXLMRobertaForMaskedLM), - (LongformerConfig, TFLongformerForMaskedLM), - (RobertaConfig, TFRobertaForMaskedLM), - (LayoutLMConfig, TFLayoutLMForMaskedLM), - (BertConfig, TFBertForMaskedLM), - (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel), - (GPT2Config, TFGPT2LMHeadModel), - (MobileBertConfig, TFMobileBertForMaskedLM), - (TransfoXLConfig, TFTransfoXLLMHeadModel), - (XLNetConfig, TFXLNetLMHeadModel), - (FlaubertConfig, TFFlaubertWithLMHeadModel), - (XLMConfig, TFXLMWithLMHeadModel), - (CTRLConfig, TFCTRLLMHeadModel), - (ElectraConfig, TFElectraForMaskedLM), - (FunnelConfig, TFFunnelForMaskedLM), - (MPNetConfig, TFMPNetForMaskedLM), + ("rembert", "TFRemBertForMaskedLM"), + ("roformer", "TFRoFormerForMaskedLM"), + ("convbert", "TFConvBertForMaskedLM"), + ("led", "TFLEDForConditionalGeneration"), + ("t5", "TFT5ForConditionalGeneration"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("albert", "TFAlbertForMaskedLM"), + ("marian", "TFMarianMTModel"), + ("bart", "TFBartForConditionalGeneration"), + ("camembert", "TFCamembertForMaskedLM"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("longformer", "TFLongformerForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("bert", "TFBertForMaskedLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("gpt2", "TFGPT2LMHeadModel"), + ("mobilebert", "TFMobileBertForMaskedLM"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), 
+ ("xlnet", "TFXLNetLMHeadModel"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("xlm", "TFXLMWithLMHeadModel"), + ("ctrl", "TFCTRLLMHeadModel"), + ("electra", "TFElectraForMaskedLM"), + ("funnel", "TFFunnelForMaskedLM"), + ("mpnet", "TFMPNetForMaskedLM"), ] ) -TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( +TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - (RemBertConfig, TFRemBertForCausalLM), - (RoFormerConfig, TFRoFormerForCausalLM), - (BertConfig, TFBertLMHeadModel), - (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel), - (GPT2Config, TFGPT2LMHeadModel), - (TransfoXLConfig, TFTransfoXLLMHeadModel), - (XLNetConfig, TFXLNetLMHeadModel), - ( - XLMConfig, - TFXLMWithLMHeadModel, - ), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now - (CTRLConfig, TFCTRLLMHeadModel), + ("rembert", "TFRemBertForCausalLM"), + ("roformer", "TFRoFormerForCausalLM"), + ("bert", "TFBertLMHeadModel"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("gpt2", "TFGPT2LMHeadModel"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), + ("xlnet", "TFXLNetLMHeadModel"), + ("xlm", "TFXLMWithLMHeadModel"), + ("ctrl", "TFCTRLLMHeadModel"), ] ) -TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( +TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - (RemBertConfig, TFRemBertForMaskedLM), - (RoFormerConfig, TFRoFormerForMaskedLM), - (ConvBertConfig, TFConvBertForMaskedLM), - (DistilBertConfig, TFDistilBertForMaskedLM), - (AlbertConfig, TFAlbertForMaskedLM), - (CamembertConfig, TFCamembertForMaskedLM), - (XLMRobertaConfig, TFXLMRobertaForMaskedLM), - (LongformerConfig, TFLongformerForMaskedLM), - (RobertaConfig, TFRobertaForMaskedLM), - (LayoutLMConfig, TFLayoutLMForMaskedLM), - (BertConfig, TFBertForMaskedLM), - (MobileBertConfig, TFMobileBertForMaskedLM), - (FlaubertConfig, TFFlaubertWithLMHeadModel), - (XLMConfig, TFXLMWithLMHeadModel), - (ElectraConfig, TFElectraForMaskedLM), - (FunnelConfig, TFFunnelForMaskedLM), - 
(MPNetConfig, TFMPNetForMaskedLM), + ("rembert", "TFRemBertForMaskedLM"), + ("roformer", "TFRoFormerForMaskedLM"), + ("convbert", "TFConvBertForMaskedLM"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("albert", "TFAlbertForMaskedLM"), + ("camembert", "TFCamembertForMaskedLM"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("longformer", "TFLongformerForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("bert", "TFBertForMaskedLM"), + ("mobilebert", "TFMobileBertForMaskedLM"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("xlm", "TFXLMWithLMHeadModel"), + ("electra", "TFElectraForMaskedLM"), + ("funnel", "TFFunnelForMaskedLM"), + ("mpnet", "TFMPNetForMaskedLM"), ] ) -TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( +TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - (LEDConfig, TFLEDForConditionalGeneration), - (MT5Config, TFMT5ForConditionalGeneration), - (T5Config, TFT5ForConditionalGeneration), - (MarianConfig, TFMarianMTModel), - (MBartConfig, TFMBartForConditionalGeneration), - (PegasusConfig, TFPegasusForConditionalGeneration), - (BlenderbotConfig, TFBlenderbotForConditionalGeneration), - (BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration), - (BartConfig, TFBartForConditionalGeneration), + ("led", "TFLEDForConditionalGeneration"), + ("mt5", "TFMT5ForConditionalGeneration"), + ("t5", "TFT5ForConditionalGeneration"), + ("marian", "TFMarianMTModel"), + ("mbart", "TFMBartForConditionalGeneration"), + ("pegasus", "TFPegasusForConditionalGeneration"), + ("blenderbot", "TFBlenderbotForConditionalGeneration"), + ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"), + ("bart", "TFBartForConditionalGeneration"), ] ) -TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( +TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - (RemBertConfig, 
TFRemBertForSequenceClassification), - (RoFormerConfig, TFRoFormerForSequenceClassification), - (ConvBertConfig, TFConvBertForSequenceClassification), - (DistilBertConfig, TFDistilBertForSequenceClassification), - (AlbertConfig, TFAlbertForSequenceClassification), - (CamembertConfig, TFCamembertForSequenceClassification), - (XLMRobertaConfig, TFXLMRobertaForSequenceClassification), - (LongformerConfig, TFLongformerForSequenceClassification), - (RobertaConfig, TFRobertaForSequenceClassification), - (LayoutLMConfig, TFLayoutLMForSequenceClassification), - (BertConfig, TFBertForSequenceClassification), - (XLNetConfig, TFXLNetForSequenceClassification), - (MobileBertConfig, TFMobileBertForSequenceClassification), - (FlaubertConfig, TFFlaubertForSequenceClassification), - (XLMConfig, TFXLMForSequenceClassification), - (ElectraConfig, TFElectraForSequenceClassification), - (FunnelConfig, TFFunnelForSequenceClassification), - (GPT2Config, TFGPT2ForSequenceClassification), - (MPNetConfig, TFMPNetForSequenceClassification), - (OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification), - (TransfoXLConfig, TFTransfoXLForSequenceClassification), - (CTRLConfig, TFCTRLForSequenceClassification), + ("rembert", "TFRemBertForSequenceClassification"), + ("roformer", "TFRoFormerForSequenceClassification"), + ("convbert", "TFConvBertForSequenceClassification"), + ("distilbert", "TFDistilBertForSequenceClassification"), + ("albert", "TFAlbertForSequenceClassification"), + ("camembert", "TFCamembertForSequenceClassification"), + ("xlm-roberta", "TFXLMRobertaForSequenceClassification"), + ("longformer", "TFLongformerForSequenceClassification"), + ("roberta", "TFRobertaForSequenceClassification"), + ("layoutlm", "TFLayoutLMForSequenceClassification"), + ("bert", "TFBertForSequenceClassification"), + ("xlnet", "TFXLNetForSequenceClassification"), + ("mobilebert", "TFMobileBertForSequenceClassification"), + ("flaubert", "TFFlaubertForSequenceClassification"), + ("xlm", 
"TFXLMForSequenceClassification"), + ("electra", "TFElectraForSequenceClassification"), + ("funnel", "TFFunnelForSequenceClassification"), + ("gpt2", "TFGPT2ForSequenceClassification"), + ("mpnet", "TFMPNetForSequenceClassification"), + ("openai-gpt", "TFOpenAIGPTForSequenceClassification"), + ("transfo-xl", "TFTransfoXLForSequenceClassification"), + ("ctrl", "TFCTRLForSequenceClassification"), ] ) -TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( +TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - (RemBertConfig, TFRemBertForQuestionAnswering), - (RoFormerConfig, TFRoFormerForQuestionAnswering), - (ConvBertConfig, TFConvBertForQuestionAnswering), - (DistilBertConfig, TFDistilBertForQuestionAnswering), - (AlbertConfig, TFAlbertForQuestionAnswering), - (CamembertConfig, TFCamembertForQuestionAnswering), - (XLMRobertaConfig, TFXLMRobertaForQuestionAnswering), - (LongformerConfig, TFLongformerForQuestionAnswering), - (RobertaConfig, TFRobertaForQuestionAnswering), - (BertConfig, TFBertForQuestionAnswering), - (XLNetConfig, TFXLNetForQuestionAnsweringSimple), - (MobileBertConfig, TFMobileBertForQuestionAnswering), - (FlaubertConfig, TFFlaubertForQuestionAnsweringSimple), - (XLMConfig, TFXLMForQuestionAnsweringSimple), - (ElectraConfig, TFElectraForQuestionAnswering), - (FunnelConfig, TFFunnelForQuestionAnswering), - (MPNetConfig, TFMPNetForQuestionAnswering), + ("rembert", "TFRemBertForQuestionAnswering"), + ("roformer", "TFRoFormerForQuestionAnswering"), + ("convbert", "TFConvBertForQuestionAnswering"), + ("distilbert", "TFDistilBertForQuestionAnswering"), + ("albert", "TFAlbertForQuestionAnswering"), + ("camembert", "TFCamembertForQuestionAnswering"), + ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"), + ("longformer", "TFLongformerForQuestionAnswering"), + ("roberta", "TFRobertaForQuestionAnswering"), + ("bert", "TFBertForQuestionAnswering"), + ("xlnet", "TFXLNetForQuestionAnsweringSimple"), + 
("mobilebert", "TFMobileBertForQuestionAnswering"), + ("flaubert", "TFFlaubertForQuestionAnsweringSimple"), + ("xlm", "TFXLMForQuestionAnsweringSimple"), + ("electra", "TFElectraForQuestionAnswering"), + ("funnel", "TFFunnelForQuestionAnswering"), + ("mpnet", "TFMPNetForQuestionAnswering"), ] ) -TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( +TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - (RemBertConfig, TFRemBertForTokenClassification), - (RoFormerConfig, TFRoFormerForTokenClassification), - (ConvBertConfig, TFConvBertForTokenClassification), - (DistilBertConfig, TFDistilBertForTokenClassification), - (AlbertConfig, TFAlbertForTokenClassification), - (CamembertConfig, TFCamembertForTokenClassification), - (FlaubertConfig, TFFlaubertForTokenClassification), - (XLMConfig, TFXLMForTokenClassification), - (XLMRobertaConfig, TFXLMRobertaForTokenClassification), - (LongformerConfig, TFLongformerForTokenClassification), - (RobertaConfig, TFRobertaForTokenClassification), - (LayoutLMConfig, TFLayoutLMForTokenClassification), - (BertConfig, TFBertForTokenClassification), - (MobileBertConfig, TFMobileBertForTokenClassification), - (XLNetConfig, TFXLNetForTokenClassification), - (ElectraConfig, TFElectraForTokenClassification), - (FunnelConfig, TFFunnelForTokenClassification), - (MPNetConfig, TFMPNetForTokenClassification), + ("rembert", "TFRemBertForTokenClassification"), + ("roformer", "TFRoFormerForTokenClassification"), + ("convbert", "TFConvBertForTokenClassification"), + ("distilbert", "TFDistilBertForTokenClassification"), + ("albert", "TFAlbertForTokenClassification"), + ("camembert", "TFCamembertForTokenClassification"), + ("flaubert", "TFFlaubertForTokenClassification"), + ("xlm", "TFXLMForTokenClassification"), + ("xlm-roberta", "TFXLMRobertaForTokenClassification"), + ("longformer", "TFLongformerForTokenClassification"), + ("roberta", "TFRobertaForTokenClassification"), + ("layoutlm", 
"TFLayoutLMForTokenClassification"), + ("bert", "TFBertForTokenClassification"), + ("mobilebert", "TFMobileBertForTokenClassification"), + ("xlnet", "TFXLNetForTokenClassification"), + ("electra", "TFElectraForTokenClassification"), + ("funnel", "TFFunnelForTokenClassification"), + ("mpnet", "TFMPNetForTokenClassification"), ] ) -TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( +TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - (RemBertConfig, TFRemBertForMultipleChoice), - (RoFormerConfig, TFRoFormerForMultipleChoice), - (ConvBertConfig, TFConvBertForMultipleChoice), - (CamembertConfig, TFCamembertForMultipleChoice), - (XLMConfig, TFXLMForMultipleChoice), - (XLMRobertaConfig, TFXLMRobertaForMultipleChoice), - (LongformerConfig, TFLongformerForMultipleChoice), - (RobertaConfig, TFRobertaForMultipleChoice), - (BertConfig, TFBertForMultipleChoice), - (DistilBertConfig, TFDistilBertForMultipleChoice), - (MobileBertConfig, TFMobileBertForMultipleChoice), - (XLNetConfig, TFXLNetForMultipleChoice), - (FlaubertConfig, TFFlaubertForMultipleChoice), - (AlbertConfig, TFAlbertForMultipleChoice), - (ElectraConfig, TFElectraForMultipleChoice), - (FunnelConfig, TFFunnelForMultipleChoice), - (MPNetConfig, TFMPNetForMultipleChoice), + ("rembert", "TFRemBertForMultipleChoice"), + ("roformer", "TFRoFormerForMultipleChoice"), + ("convbert", "TFConvBertForMultipleChoice"), + ("camembert", "TFCamembertForMultipleChoice"), + ("xlm", "TFXLMForMultipleChoice"), + ("xlm-roberta", "TFXLMRobertaForMultipleChoice"), + ("longformer", "TFLongformerForMultipleChoice"), + ("roberta", "TFRobertaForMultipleChoice"), + ("bert", "TFBertForMultipleChoice"), + ("distilbert", "TFDistilBertForMultipleChoice"), + ("mobilebert", "TFMobileBertForMultipleChoice"), + ("xlnet", "TFXLNetForMultipleChoice"), + ("flaubert", "TFFlaubertForMultipleChoice"), + ("albert", "TFAlbertForMultipleChoice"), + ("electra", "TFElectraForMultipleChoice"), + ("funnel", 
"TFFunnelForMultipleChoice"), + ("mpnet", "TFMPNetForMultipleChoice"), ] ) -TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( +TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( [ - (BertConfig, TFBertForNextSentencePrediction), - (MobileBertConfig, TFMobileBertForNextSentencePrediction), + ("bert", "TFBertForNextSentencePrediction"), + ("mobilebert", "TFMobileBertForNextSentencePrediction"), ] ) +TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES) +TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) +TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES) +TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES) +TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES +) +TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES +) +TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES +) +TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES +) +TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES +) +TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES +) + + class TFAutoModel(_BaseAutoModelClass): _model_mapping = TF_MODEL_MAPPING diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 
254dd34e81e..6a4ccdc07fc 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -14,12 +14,12 @@ # limitations under the License. """ Auto Tokenizer class. """ +import importlib import json import os from collections import OrderedDict from typing import Dict, Optional, Union -from ... import GPTNeoConfig from ...configuration_utils import PretrainedConfig from ...file_utils import ( cached_path, @@ -29,315 +29,183 @@ from ...file_utils import ( is_tokenizers_available, ) from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE +from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging -from ..bart.tokenization_bart import BartTokenizer -from ..bert.tokenization_bert import BertTokenizer -from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer -from ..bertweet.tokenization_bertweet import BertweetTokenizer -from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer -from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer -from ..byt5.tokenization_byt5 import ByT5Tokenizer -from ..canine.tokenization_canine import CanineTokenizer -from ..convbert.tokenization_convbert import ConvBertTokenizer -from ..ctrl.tokenization_ctrl import CTRLTokenizer -from ..deberta.tokenization_deberta import DebertaTokenizer -from ..distilbert.tokenization_distilbert import DistilBertTokenizer -from ..dpr.tokenization_dpr import DPRQuestionEncoderTokenizer -from ..electra.tokenization_electra import ElectraTokenizer -from ..flaubert.tokenization_flaubert import FlaubertTokenizer -from ..fsmt.tokenization_fsmt import FSMTTokenizer -from ..funnel.tokenization_funnel import FunnelTokenizer -from ..gpt2.tokenization_gpt2 import GPT2Tokenizer -from ..herbert.tokenization_herbert import HerbertTokenizer -from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer -from ..led.tokenization_led import LEDTokenizer -from 
..longformer.tokenization_longformer import LongformerTokenizer -from ..luke.tokenization_luke import LukeTokenizer -from ..lxmert.tokenization_lxmert import LxmertTokenizer -from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer -from ..mpnet.tokenization_mpnet import MPNetTokenizer -from ..openai.tokenization_openai import OpenAIGPTTokenizer -from ..phobert.tokenization_phobert import PhobertTokenizer -from ..prophetnet.tokenization_prophetnet import ProphetNetTokenizer -from ..rag.tokenization_rag import RagTokenizer -from ..retribert.tokenization_retribert import RetriBertTokenizer -from ..roberta.tokenization_roberta import RobertaTokenizer -from ..roformer.tokenization_roformer import RoFormerTokenizer -from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer -from ..tapas.tokenization_tapas import TapasTokenizer -from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer -from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer -from ..xlm.tokenization_xlm import XLMTokenizer +from ..encoder_decoder import EncoderDecoderConfig +from .auto_factory import _LazyAutoMapping from .configuration_auto import ( - AlbertConfig, + CONFIG_MAPPING_NAMES, AutoConfig, - BartConfig, - BertConfig, - BertGenerationConfig, - BigBirdConfig, - BigBirdPegasusConfig, - BlenderbotConfig, - BlenderbotSmallConfig, - CamembertConfig, - CanineConfig, - ConvBertConfig, - CTRLConfig, - DebertaConfig, - DebertaV2Config, - DistilBertConfig, - DPRConfig, - ElectraConfig, - EncoderDecoderConfig, - FlaubertConfig, - FSMTConfig, - FunnelConfig, - GPT2Config, - HubertConfig, - IBertConfig, - LayoutLMConfig, - LEDConfig, - LongformerConfig, - LukeConfig, - LxmertConfig, - M2M100Config, - MarianConfig, - MBartConfig, - MobileBertConfig, - MPNetConfig, - MT5Config, - OpenAIGPTConfig, - PegasusConfig, - ProphetNetConfig, - RagConfig, - ReformerConfig, - RetriBertConfig, - RobertaConfig, - RoFormerConfig, - Speech2TextConfig, - SqueezeBertConfig, - 
T5Config, - TapasConfig, - TransfoXLConfig, - Wav2Vec2Config, - XLMConfig, - XLMProphetNetConfig, - XLMRobertaConfig, - XLNetConfig, + config_class_to_model_type, replace_list_option_in_docstrings, ) -if is_sentencepiece_available(): - from ..albert.tokenization_albert import AlbertTokenizer - from ..barthez.tokenization_barthez import BarthezTokenizer - from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer - from ..big_bird.tokenization_big_bird import BigBirdTokenizer - from ..camembert.tokenization_camembert import CamembertTokenizer - from ..cpm.tokenization_cpm import CpmTokenizer - from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer - from ..m2m_100 import M2M100Tokenizer - from ..marian.tokenization_marian import MarianTokenizer - from ..mbart.tokenization_mbart import MBartTokenizer - from ..mbart.tokenization_mbart50 import MBart50Tokenizer - from ..mt5 import MT5Tokenizer - from ..pegasus.tokenization_pegasus import PegasusTokenizer - from ..reformer.tokenization_reformer import ReformerTokenizer - from ..speech_to_text import Speech2TextTokenizer - from ..t5.tokenization_t5 import T5Tokenizer - from ..xlm_prophetnet.tokenization_xlm_prophetnet import XLMProphetNetTokenizer - from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer - from ..xlnet.tokenization_xlnet import XLNetTokenizer -else: - AlbertTokenizer = None - BarthezTokenizer = None - BertGenerationTokenizer = None - BigBirdTokenizer = None - CamembertTokenizer = None - CpmTokenizer = None - DebertaV2Tokenizer = None - MarianTokenizer = None - MBartTokenizer = None - MBart50Tokenizer = None - MT5Tokenizer = None - PegasusTokenizer = None - ReformerTokenizer = None - T5Tokenizer = None - XLMRobertaTokenizer = None - XLNetTokenizer = None - XLMProphetNetTokenizer = None - M2M100Tokenizer = None - Speech2TextTokenizer = None - -if is_tokenizers_available(): - from ...tokenization_utils_fast import PreTrainedTokenizerFast - from 
..albert.tokenization_albert_fast import AlbertTokenizerFast - from ..bart.tokenization_bart_fast import BartTokenizerFast - from ..barthez.tokenization_barthez_fast import BarthezTokenizerFast - from ..bert.tokenization_bert_fast import BertTokenizerFast - from ..big_bird.tokenization_big_bird_fast import BigBirdTokenizerFast - from ..camembert.tokenization_camembert_fast import CamembertTokenizerFast - from ..convbert.tokenization_convbert_fast import ConvBertTokenizerFast - from ..cpm.tokenization_cpm_fast import CpmTokenizerFast - from ..deberta.tokenization_deberta_fast import DebertaTokenizerFast - from ..distilbert.tokenization_distilbert_fast import DistilBertTokenizerFast - from ..dpr.tokenization_dpr_fast import DPRQuestionEncoderTokenizerFast - from ..electra.tokenization_electra_fast import ElectraTokenizerFast - from ..funnel.tokenization_funnel_fast import FunnelTokenizerFast - from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast - from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast - from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast - from ..led.tokenization_led_fast import LEDTokenizerFast - from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast - from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast - from ..mbart.tokenization_mbart50_fast import MBart50TokenizerFast - from ..mbart.tokenization_mbart_fast import MBartTokenizerFast - from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast - from ..mpnet.tokenization_mpnet_fast import MPNetTokenizerFast - from ..mt5 import MT5TokenizerFast - from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast - from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast - from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast - from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast - from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast - 
from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast - from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast - from ..t5.tokenization_t5_fast import T5TokenizerFast - from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast - from ..xlnet.tokenization_xlnet_fast import XLNetTokenizerFast - -else: - AlbertTokenizerFast = None - BartTokenizerFast = None - BarthezTokenizerFast = None - BertTokenizerFast = None - BigBirdTokenizerFast = None - CamembertTokenizerFast = None - ConvBertTokenizerFast = None - CpmTokenizerFast = None - DebertaTokenizerFast = None - DistilBertTokenizerFast = None - DPRQuestionEncoderTokenizerFast = None - ElectraTokenizerFast = None - FunnelTokenizerFast = None - GPT2TokenizerFast = None - HerbertTokenizerFast = None - LayoutLMTokenizerFast = None - LEDTokenizerFast = None - LongformerTokenizerFast = None - LxmertTokenizerFast = None - MBartTokenizerFast = None - MBart50TokenizerFast = None - MobileBertTokenizerFast = None - MPNetTokenizerFast = None - MT5TokenizerFast = None - OpenAIGPTTokenizerFast = None - PegasusTokenizerFast = None - ReformerTokenizerFast = None - RetriBertTokenizerFast = None - RobertaTokenizerFast = None - RoFormerTokenizerFast = None - SqueezeBertTokenizerFast = None - T5TokenizerFast = None - XLMRobertaTokenizerFast = None - XLNetTokenizerFast = None - PreTrainedTokenizerFast = None - - logger = logging.get_logger(__name__) -TOKENIZER_MAPPING = OrderedDict( +TOKENIZER_MAPPING_NAMES = OrderedDict( [ - (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)), - (RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)), - (T5Config, (T5Tokenizer, T5TokenizerFast)), - (MT5Config, (MT5Tokenizer, MT5TokenizerFast)), - (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)), - (DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)), - (AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)), - (CamembertConfig, 
(CamembertTokenizer, CamembertTokenizerFast)), - (PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)), - (MBartConfig, (MBartTokenizer, MBartTokenizerFast)), - (XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)), - (MarianConfig, (MarianTokenizer, None)), - (BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)), - (BlenderbotConfig, (BlenderbotTokenizer, None)), - (BartConfig, (BartTokenizer, BartTokenizerFast)), - (LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)), - (RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)), - (ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)), - (ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)), - (FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)), - (LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)), - (LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)), - (DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)), - (SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)), - (BertConfig, (BertTokenizer, BertTokenizerFast)), - (OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)), - (GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)), - (TransfoXLConfig, (TransfoXLTokenizer, None)), - (XLNetConfig, (XLNetTokenizer, XLNetTokenizerFast)), - (FlaubertConfig, (FlaubertTokenizer, None)), - (XLMConfig, (XLMTokenizer, None)), - (CTRLConfig, (CTRLTokenizer, None)), - (FSMTConfig, (FSMTTokenizer, None)), - (BertGenerationConfig, (BertGenerationTokenizer, None)), - (DebertaConfig, (DebertaTokenizer, DebertaTokenizerFast)), - (DebertaV2Config, (DebertaV2Tokenizer, None)), - (RagConfig, (RagTokenizer, None)), - (XLMProphetNetConfig, (XLMProphetNetTokenizer, None)), - (Speech2TextConfig, (Speech2TextTokenizer, None)), - (M2M100Config, (M2M100Tokenizer, None)), - (ProphetNetConfig, (ProphetNetTokenizer, None)), - (MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)), - (TapasConfig, (TapasTokenizer, None)), - (LEDConfig, (LEDTokenizer, 
LEDTokenizerFast)), - (ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)), - (BigBirdConfig, (BigBirdTokenizer, BigBirdTokenizerFast)), - (IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)), - (Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)), - (HubertConfig, (Wav2Vec2CTCTokenizer, None)), - (GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)), - (LukeConfig, (LukeTokenizer, None)), - (BigBirdPegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)), - (CanineConfig, (CanineTokenizer, None)), + ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), + ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), + ( + "t5", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "mt5", + ( + "MT5Tokenizer" if is_sentencepiece_available() else None, + "MT5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)), + ( + "albert", + ( + "AlbertTokenizer" if is_sentencepiece_available() else None, + "AlbertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "camembert", + ( + "CamembertTokenizer" if is_sentencepiece_available() else None, + "CamembertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "pegasus", + ( + "PegasusTokenizer" if is_sentencepiece_available() else None, + "PegasusTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "mbart", + ( + "MBartTokenizer" if is_sentencepiece_available() else None, + "MBartTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "xlm-roberta", + ( + "XLMRobertaTokenizer" if is_sentencepiece_available() else None, + "XLMRobertaTokenizerFast" if 
is_tokenizers_available() else None, + ), + ), + ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), + ("blenderbot-small", ("BlenderbotSmallTokenizer", None)), + ("blenderbot", ("BlenderbotTokenizer", None)), + ("bart", ("BartTokenizer", "BartTokenizerFast")), + ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), + ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ( + "reformer", + ( + "ReformerTokenizer" if is_sentencepiece_available() else None, + "ReformerTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), + ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), + ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), + ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), + ( + "dpr", + ("DPRQuestionEncoderTokenizer", "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None), + ), + ("squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None)), + ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), + ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("transfo-xl", ("TransfoXLTokenizer", None)), + ( + "xlnet", + ( + "XLNetTokenizer" if is_sentencepiece_available() else None, + "XLNetTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("flaubert", ("FlaubertTokenizer", None)), + ("xlm", ("XLMTokenizer", None)), + ("ctrl", ("CTRLTokenizer", None)), + ("fsmt", ("FSMTTokenizer", None)), + ("bert-generation", ("BertGenerationTokenizer" if 
is_sentencepiece_available() else None, None)), + ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)), + ("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)), + ("rag", ("RagTokenizer", None)), + ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)), + ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), + ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), + ("prophetnet", ("ProphetNetTokenizer", None)), + ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), + ("tapas", ("TapasTokenizer", None)), + ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), + ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), + ( + "big_bird", + ( + "BigBirdTokenizer" if is_sentencepiece_available() else None, + "BigBirdTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), + ("hubert", ("Wav2Vec2CTCTokenizer", None)), + ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), + ("luke", ("LukeTokenizer", None)), + ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), + ("canine", ("CanineTokenizer", None)), + ("bertweet", ("BertweetTokenizer", None)), + ("bert-japanese", ("BertJapaneseTokenizer", None)), + ("byt5", ("ByT5Tokenizer", None)), + ( + "cpm", + ( + "CpmTokenizer" if is_sentencepiece_available() else None, + "CpmTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), + ("phobert", ("PhobertTokenizer", None)), + ( + 
"barthez", + ( + "BarthezTokenizer" if is_sentencepiece_available() else None, + "BarthezTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "mbart50", + ( + "MBart50Tokenizer" if is_sentencepiece_available() else None, + "MBart50TokenizerFast" if is_tokenizers_available() else None, + ), + ), ] ) -# For tokenizers which are not directly mapped from a config -NO_CONFIG_TOKENIZER = [ - BertJapaneseTokenizer, - BertweetTokenizer, - ByT5Tokenizer, - CpmTokenizer, - CpmTokenizerFast, - HerbertTokenizer, - HerbertTokenizerFast, - PhobertTokenizer, - BarthezTokenizer, - BarthezTokenizerFast, - MBart50Tokenizer, - MBart50TokenizerFast, - PreTrainedTokenizerFast, -] +TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES) - -SLOW_TOKENIZER_MAPPING = { - k: (v[0] if v[0] is not None else v[1]) - for k, v in TOKENIZER_MAPPING.items() - if (v[0] is not None or v[1] is not None) -} +CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()} def tokenizer_class_from_name(class_name: str): - all_tokenizer_classes = ( - [v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None] - + [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None] - + [v for v in NO_CONFIG_TOKENIZER if v is not None] - ) - for c in all_tokenizer_classes: - if c.__name__ == class_name: - return c + if class_name == "PreTrainedTokenizerFast": + return PreTrainedTokenizerFast + + for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items(): + if class_name in tokenizers: + break + + module = importlib.import_module(f".{module_name}", "transformers.models") + return getattr(module, class_name) def get_tokenizer_config( @@ -454,7 +322,7 @@ class AutoTokenizer: ) @classmethod - @replace_list_option_in_docstrings(SLOW_TOKENIZER_MAPPING) + @replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES) def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): r""" Instantiate one of the tokenizer classes of the library from a pretrained 
model vocabulary. @@ -565,7 +433,8 @@ class AutoTokenizer: ) config = config.encoder - if type(config) in TOKENIZER_MAPPING.keys(): + model_type = config_class_to_model_type(type(config).__name__) + if model_type is not None: tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)] if tokenizer_class_fast and (use_fast or tokenizer_class_py is None): return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index 9ad6b6dfec5..f17ab91b279 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -33,10 +33,8 @@ _import_structure = { if is_sentencepiece_available(): _import_structure["tokenization_mbart"] = ["MBartTokenizer"] - _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"] if is_tokenizers_available(): - _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"] _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"] if is_torch_available(): @@ -72,10 +70,8 @@ if TYPE_CHECKING: if is_sentencepiece_available(): from .tokenization_mbart import MBartTokenizer - from .tokenization_mbart50 import MBart50Tokenizer if is_tokenizers_available(): - from .tokenization_mbart50_fast import MBart50TokenizerFast from .tokenization_mbart_fast import MBartTokenizerFast if is_torch_available(): diff --git a/src/transformers/models/mbart50/__init__.py b/src/transformers/models/mbart50/__init__.py new file mode 100644 index 00000000000..299821063de --- /dev/null +++ b/src/transformers/models/mbart50/__init__.py @@ -0,0 +1,42 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available + + +_import_structure = {} + +if is_sentencepiece_available(): + _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"] + +if is_tokenizers_available(): + _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"] + + +if TYPE_CHECKING: + if is_sentencepiece_available(): + from .tokenization_mbart50 import MBart50Tokenizer + + if is_tokenizers_available(): + from .tokenization_mbart50_fast import MBart50TokenizerFast + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/mbart/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py similarity index 100% rename from src/transformers/models/mbart/tokenization_mbart50.py rename to src/transformers/models/mbart50/tokenization_mbart50.py diff --git a/src/transformers/models/mbart/tokenization_mbart50_fast.py b/src/transformers/models/mbart50/tokenization_mbart50_fast.py similarity index 100% rename from src/transformers/models/mbart/tokenization_mbart50_fast.py rename to src/transformers/models/mbart50/tokenization_mbart50_fast.py diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 132a852eda7..762c966047b 100644 --- 
a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1781,25 +1781,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): if config_tokenizer_class is None: # Third attempt. If we have not yet found the original type of the tokenizer, # we are loading we see if we can infer it from the type of the configuration file - from .models.auto.configuration_auto import CONFIG_MAPPING # tests_ignore - from .models.auto.tokenization_auto import TOKENIZER_MAPPING # tests_ignore + from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore if hasattr(config, "model_type"): - config_class = CONFIG_MAPPING.get(config.model_type) + model_type = config.model_type else: # Fallback: use pattern matching on the string. - config_class = None - for pattern, config_class_tmp in CONFIG_MAPPING.items(): + model_type = None + for pattern in TOKENIZER_MAPPING_NAMES.keys(): if pattern in str(pretrained_model_name_or_path): - config_class = config_class_tmp + model_type = pattern break - if config_class in TOKENIZER_MAPPING.keys(): - config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class] - if config_tokenizer_class is not None: - config_tokenizer_class = config_tokenizer_class.__name__ - else: - config_tokenizer_class = config_tokenizer_class_fast.__name__ + if model_type is not None: + config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES[model_type] + if config_tokenizer_class is None: + config_tokenizer_class = config_tokenizer_class_fast if config_tokenizer_class is not None: if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 84abde3e3fa..3119d807167 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -74,6 +74,7 @@ from .file_utils import ( ) from .modelcard import TrainingSummary from .modeling_utils import 
PreTrainedModel, unwrap_model +from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES from .optimization import Adafactor, AdamW, get_scheduler from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( @@ -125,7 +126,6 @@ from .trainer_utils import ( ) from .training_args import ParallelMode, TrainingArguments from .utils import logging -from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES _is_torch_generator_available = False diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 14e2d74a21a..2662bc4823d 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -191,7 +191,7 @@ class LxmertTokenizerFast: requires_backends(cls, ["tokenizers"]) -class MBart50TokenizerFast: +class MBartTokenizerFast: def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) @@ -200,7 +200,7 @@ class MBart50TokenizerFast: requires_backends(cls, ["tokenizers"]) -class MBartTokenizerFast: +class MBart50TokenizerFast: def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py deleted file mode 100644 index 309ba38449e..00000000000 --- a/src/transformers/utils/modeling_auto_mapping.py +++ /dev/null @@ -1,374 +0,0 @@ -# THIS FILE HAS BEEN AUTOGENERATED. To update: -# 1. modify: models/auto/modeling_auto.py -# 2. 
run: python utils/class_mapping_update.py -from collections import OrderedDict - - -MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForQuestionAnswering"), - ("CanineConfig", "CanineForQuestionAnswering"), - ("RoFormerConfig", "RoFormerForQuestionAnswering"), - ("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"), - ("BigBirdConfig", "BigBirdForQuestionAnswering"), - ("ConvBertConfig", "ConvBertForQuestionAnswering"), - ("LEDConfig", "LEDForQuestionAnswering"), - ("DistilBertConfig", "DistilBertForQuestionAnswering"), - ("AlbertConfig", "AlbertForQuestionAnswering"), - ("CamembertConfig", "CamembertForQuestionAnswering"), - ("BartConfig", "BartForQuestionAnswering"), - ("MBartConfig", "MBartForQuestionAnswering"), - ("LongformerConfig", "LongformerForQuestionAnswering"), - ("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"), - ("RobertaConfig", "RobertaForQuestionAnswering"), - ("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"), - ("BertConfig", "BertForQuestionAnswering"), - ("XLNetConfig", "XLNetForQuestionAnsweringSimple"), - ("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"), - ("MegatronBertConfig", "MegatronBertForQuestionAnswering"), - ("MobileBertConfig", "MobileBertForQuestionAnswering"), - ("XLMConfig", "XLMForQuestionAnsweringSimple"), - ("ElectraConfig", "ElectraForQuestionAnswering"), - ("ReformerConfig", "ReformerForQuestionAnswering"), - ("FunnelConfig", "FunnelForQuestionAnswering"), - ("LxmertConfig", "LxmertForQuestionAnswering"), - ("MPNetConfig", "MPNetForQuestionAnswering"), - ("DebertaConfig", "DebertaForQuestionAnswering"), - ("DebertaV2Config", "DebertaV2ForQuestionAnswering"), - ("IBertConfig", "IBertForQuestionAnswering"), - ] -) - - -MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForCausalLM"), - ("RoFormerConfig", "RoFormerForCausalLM"), - ("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"), - ("GPTNeoConfig", "GPTNeoForCausalLM"), - 
("BigBirdConfig", "BigBirdForCausalLM"), - ("CamembertConfig", "CamembertForCausalLM"), - ("XLMRobertaConfig", "XLMRobertaForCausalLM"), - ("RobertaConfig", "RobertaForCausalLM"), - ("BertConfig", "BertLMHeadModel"), - ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"), - ("GPT2Config", "GPT2LMHeadModel"), - ("TransfoXLConfig", "TransfoXLLMHeadModel"), - ("XLNetConfig", "XLNetLMHeadModel"), - ("XLMConfig", "XLMWithLMHeadModel"), - ("CTRLConfig", "CTRLLMHeadModel"), - ("ReformerConfig", "ReformerModelWithLMHead"), - ("BertGenerationConfig", "BertGenerationDecoder"), - ("XLMProphetNetConfig", "XLMProphetNetForCausalLM"), - ("ProphetNetConfig", "ProphetNetForCausalLM"), - ("BartConfig", "BartForCausalLM"), - ("MBartConfig", "MBartForCausalLM"), - ("PegasusConfig", "PegasusForCausalLM"), - ("MarianConfig", "MarianForCausalLM"), - ("BlenderbotConfig", "BlenderbotForCausalLM"), - ("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"), - ("MegatronBertConfig", "MegatronBertForCausalLM"), - ] -) - - -MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - ("ViTConfig", "ViTForImageClassification"), - ("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"), - ("BeitConfig", "BeitForImageClassification"), - ] -) - - -MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForMaskedLM"), - ("RoFormerConfig", "RoFormerForMaskedLM"), - ("BigBirdConfig", "BigBirdForMaskedLM"), - ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"), - ("ConvBertConfig", "ConvBertForMaskedLM"), - ("LayoutLMConfig", "LayoutLMForMaskedLM"), - ("DistilBertConfig", "DistilBertForMaskedLM"), - ("AlbertConfig", "AlbertForMaskedLM"), - ("BartConfig", "BartForConditionalGeneration"), - ("MBartConfig", "MBartForConditionalGeneration"), - ("CamembertConfig", "CamembertForMaskedLM"), - ("XLMRobertaConfig", "XLMRobertaForMaskedLM"), - ("LongformerConfig", "LongformerForMaskedLM"), - ("RobertaConfig", "RobertaForMaskedLM"), - ("SqueezeBertConfig", 
"SqueezeBertForMaskedLM"), - ("BertConfig", "BertForMaskedLM"), - ("MegatronBertConfig", "MegatronBertForMaskedLM"), - ("MobileBertConfig", "MobileBertForMaskedLM"), - ("FlaubertConfig", "FlaubertWithLMHeadModel"), - ("XLMConfig", "XLMWithLMHeadModel"), - ("ElectraConfig", "ElectraForMaskedLM"), - ("ReformerConfig", "ReformerForMaskedLM"), - ("FunnelConfig", "FunnelForMaskedLM"), - ("MPNetConfig", "MPNetForMaskedLM"), - ("TapasConfig", "TapasForMaskedLM"), - ("DebertaConfig", "DebertaForMaskedLM"), - ("DebertaV2Config", "DebertaV2ForMaskedLM"), - ("IBertConfig", "IBertForMaskedLM"), - ] -) - - -MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForMultipleChoice"), - ("CanineConfig", "CanineForMultipleChoice"), - ("RoFormerConfig", "RoFormerForMultipleChoice"), - ("BigBirdConfig", "BigBirdForMultipleChoice"), - ("ConvBertConfig", "ConvBertForMultipleChoice"), - ("CamembertConfig", "CamembertForMultipleChoice"), - ("ElectraConfig", "ElectraForMultipleChoice"), - ("XLMRobertaConfig", "XLMRobertaForMultipleChoice"), - ("LongformerConfig", "LongformerForMultipleChoice"), - ("RobertaConfig", "RobertaForMultipleChoice"), - ("SqueezeBertConfig", "SqueezeBertForMultipleChoice"), - ("BertConfig", "BertForMultipleChoice"), - ("DistilBertConfig", "DistilBertForMultipleChoice"), - ("MegatronBertConfig", "MegatronBertForMultipleChoice"), - ("MobileBertConfig", "MobileBertForMultipleChoice"), - ("XLNetConfig", "XLNetForMultipleChoice"), - ("AlbertConfig", "AlbertForMultipleChoice"), - ("XLMConfig", "XLMForMultipleChoice"), - ("FlaubertConfig", "FlaubertForMultipleChoice"), - ("FunnelConfig", "FunnelForMultipleChoice"), - ("MPNetConfig", "MPNetForMultipleChoice"), - ("IBertConfig", "IBertForMultipleChoice"), - ] -) - - -MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( - [ - ("BertConfig", "BertForNextSentencePrediction"), - ("MegatronBertConfig", "MegatronBertForNextSentencePrediction"), - ("MobileBertConfig", 
"MobileBertForNextSentencePrediction"), - ] -) - - -MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( - [ - ("DetrConfig", "DetrForObjectDetection"), - ] -) - - -MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( - [ - ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"), - ("M2M100Config", "M2M100ForConditionalGeneration"), - ("LEDConfig", "LEDForConditionalGeneration"), - ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"), - ("MT5Config", "MT5ForConditionalGeneration"), - ("T5Config", "T5ForConditionalGeneration"), - ("PegasusConfig", "PegasusForConditionalGeneration"), - ("MarianConfig", "MarianMTModel"), - ("MBartConfig", "MBartForConditionalGeneration"), - ("BlenderbotConfig", "BlenderbotForConditionalGeneration"), - ("BartConfig", "BartForConditionalGeneration"), - ("FSMTConfig", "FSMTForConditionalGeneration"), - ("EncoderDecoderConfig", "EncoderDecoderModel"), - ("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"), - ("ProphetNetConfig", "ProphetNetForConditionalGeneration"), - ] -) - - -MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForSequenceClassification"), - ("CanineConfig", "CanineForSequenceClassification"), - ("RoFormerConfig", "RoFormerForSequenceClassification"), - ("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"), - ("BigBirdConfig", "BigBirdForSequenceClassification"), - ("ConvBertConfig", "ConvBertForSequenceClassification"), - ("LEDConfig", "LEDForSequenceClassification"), - ("DistilBertConfig", "DistilBertForSequenceClassification"), - ("AlbertConfig", "AlbertForSequenceClassification"), - ("CamembertConfig", "CamembertForSequenceClassification"), - ("XLMRobertaConfig", "XLMRobertaForSequenceClassification"), - ("MBartConfig", "MBartForSequenceClassification"), - ("BartConfig", "BartForSequenceClassification"), - ("LongformerConfig", "LongformerForSequenceClassification"), - ("RobertaConfig", 
"RobertaForSequenceClassification"), - ("SqueezeBertConfig", "SqueezeBertForSequenceClassification"), - ("LayoutLMConfig", "LayoutLMForSequenceClassification"), - ("BertConfig", "BertForSequenceClassification"), - ("XLNetConfig", "XLNetForSequenceClassification"), - ("MegatronBertConfig", "MegatronBertForSequenceClassification"), - ("MobileBertConfig", "MobileBertForSequenceClassification"), - ("FlaubertConfig", "FlaubertForSequenceClassification"), - ("XLMConfig", "XLMForSequenceClassification"), - ("ElectraConfig", "ElectraForSequenceClassification"), - ("FunnelConfig", "FunnelForSequenceClassification"), - ("DebertaConfig", "DebertaForSequenceClassification"), - ("DebertaV2Config", "DebertaV2ForSequenceClassification"), - ("GPT2Config", "GPT2ForSequenceClassification"), - ("GPTNeoConfig", "GPTNeoForSequenceClassification"), - ("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"), - ("ReformerConfig", "ReformerForSequenceClassification"), - ("CTRLConfig", "CTRLForSequenceClassification"), - ("TransfoXLConfig", "TransfoXLForSequenceClassification"), - ("MPNetConfig", "MPNetForSequenceClassification"), - ("TapasConfig", "TapasForSequenceClassification"), - ("IBertConfig", "IBertForSequenceClassification"), - ] -) - - -MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( - [ - ("TapasConfig", "TapasForQuestionAnswering"), - ] -) - - -MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForTokenClassification"), - ("CanineConfig", "CanineForTokenClassification"), - ("RoFormerConfig", "RoFormerForTokenClassification"), - ("BigBirdConfig", "BigBirdForTokenClassification"), - ("ConvBertConfig", "ConvBertForTokenClassification"), - ("LayoutLMConfig", "LayoutLMForTokenClassification"), - ("DistilBertConfig", "DistilBertForTokenClassification"), - ("CamembertConfig", "CamembertForTokenClassification"), - ("FlaubertConfig", "FlaubertForTokenClassification"), - ("XLMConfig", "XLMForTokenClassification"), - 
("XLMRobertaConfig", "XLMRobertaForTokenClassification"), - ("LongformerConfig", "LongformerForTokenClassification"), - ("RobertaConfig", "RobertaForTokenClassification"), - ("SqueezeBertConfig", "SqueezeBertForTokenClassification"), - ("BertConfig", "BertForTokenClassification"), - ("MegatronBertConfig", "MegatronBertForTokenClassification"), - ("MobileBertConfig", "MobileBertForTokenClassification"), - ("XLNetConfig", "XLNetForTokenClassification"), - ("AlbertConfig", "AlbertForTokenClassification"), - ("ElectraConfig", "ElectraForTokenClassification"), - ("FunnelConfig", "FunnelForTokenClassification"), - ("MPNetConfig", "MPNetForTokenClassification"), - ("DebertaConfig", "DebertaForTokenClassification"), - ("DebertaV2Config", "DebertaV2ForTokenClassification"), - ("IBertConfig", "IBertForTokenClassification"), - ] -) - - -MODEL_MAPPING_NAMES = OrderedDict( - [ - ("BeitConfig", "BeitModel"), - ("RemBertConfig", "RemBertModel"), - ("VisualBertConfig", "VisualBertModel"), - ("CanineConfig", "CanineModel"), - ("RoFormerConfig", "RoFormerModel"), - ("CLIPConfig", "CLIPModel"), - ("BigBirdPegasusConfig", "BigBirdPegasusModel"), - ("DeiTConfig", "DeiTModel"), - ("LukeConfig", "LukeModel"), - ("DetrConfig", "DetrModel"), - ("GPTNeoConfig", "GPTNeoModel"), - ("BigBirdConfig", "BigBirdModel"), - ("Speech2TextConfig", "Speech2TextModel"), - ("ViTConfig", "ViTModel"), - ("Wav2Vec2Config", "Wav2Vec2Model"), - ("HubertConfig", "HubertModel"), - ("M2M100Config", "M2M100Model"), - ("ConvBertConfig", "ConvBertModel"), - ("LEDConfig", "LEDModel"), - ("BlenderbotSmallConfig", "BlenderbotSmallModel"), - ("RetriBertConfig", "RetriBertModel"), - ("MT5Config", "MT5Model"), - ("T5Config", "T5Model"), - ("PegasusConfig", "PegasusModel"), - ("MarianConfig", "MarianModel"), - ("MBartConfig", "MBartModel"), - ("BlenderbotConfig", "BlenderbotModel"), - ("DistilBertConfig", "DistilBertModel"), - ("AlbertConfig", "AlbertModel"), - ("CamembertConfig", "CamembertModel"), - ("XLMRobertaConfig", 
"XLMRobertaModel"), - ("BartConfig", "BartModel"), - ("LongformerConfig", "LongformerModel"), - ("RobertaConfig", "RobertaModel"), - ("LayoutLMConfig", "LayoutLMModel"), - ("SqueezeBertConfig", "SqueezeBertModel"), - ("BertConfig", "BertModel"), - ("OpenAIGPTConfig", "OpenAIGPTModel"), - ("GPT2Config", "GPT2Model"), - ("MegatronBertConfig", "MegatronBertModel"), - ("MobileBertConfig", "MobileBertModel"), - ("TransfoXLConfig", "TransfoXLModel"), - ("XLNetConfig", "XLNetModel"), - ("FlaubertConfig", "FlaubertModel"), - ("FSMTConfig", "FSMTModel"), - ("XLMConfig", "XLMModel"), - ("CTRLConfig", "CTRLModel"), - ("ElectraConfig", "ElectraModel"), - ("ReformerConfig", "ReformerModel"), - ("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"), - ("LxmertConfig", "LxmertModel"), - ("BertGenerationConfig", "BertGenerationEncoder"), - ("DebertaConfig", "DebertaModel"), - ("DebertaV2Config", "DebertaV2Model"), - ("DPRConfig", "DPRQuestionEncoder"), - ("XLMProphetNetConfig", "XLMProphetNetModel"), - ("ProphetNetConfig", "ProphetNetModel"), - ("MPNetConfig", "MPNetModel"), - ("TapasConfig", "TapasModel"), - ("IBertConfig", "IBertModel"), - ] -) - - -MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( - [ - ("RemBertConfig", "RemBertForMaskedLM"), - ("RoFormerConfig", "RoFormerForMaskedLM"), - ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"), - ("GPTNeoConfig", "GPTNeoForCausalLM"), - ("BigBirdConfig", "BigBirdForMaskedLM"), - ("Speech2TextConfig", "Speech2TextForConditionalGeneration"), - ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"), - ("M2M100Config", "M2M100ForConditionalGeneration"), - ("ConvBertConfig", "ConvBertForMaskedLM"), - ("LEDConfig", "LEDForConditionalGeneration"), - ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"), - ("LayoutLMConfig", "LayoutLMForMaskedLM"), - ("T5Config", "T5ForConditionalGeneration"), - ("DistilBertConfig", "DistilBertForMaskedLM"), - ("AlbertConfig", "AlbertForMaskedLM"), - ("CamembertConfig", 
"CamembertForMaskedLM"), - ("XLMRobertaConfig", "XLMRobertaForMaskedLM"), - ("MarianConfig", "MarianMTModel"), - ("FSMTConfig", "FSMTForConditionalGeneration"), - ("BartConfig", "BartForConditionalGeneration"), - ("LongformerConfig", "LongformerForMaskedLM"), - ("RobertaConfig", "RobertaForMaskedLM"), - ("SqueezeBertConfig", "SqueezeBertForMaskedLM"), - ("BertConfig", "BertForMaskedLM"), - ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"), - ("GPT2Config", "GPT2LMHeadModel"), - ("MegatronBertConfig", "MegatronBertForCausalLM"), - ("MobileBertConfig", "MobileBertForMaskedLM"), - ("TransfoXLConfig", "TransfoXLLMHeadModel"), - ("XLNetConfig", "XLNetLMHeadModel"), - ("FlaubertConfig", "FlaubertWithLMHeadModel"), - ("XLMConfig", "XLMWithLMHeadModel"), - ("CTRLConfig", "CTRLLMHeadModel"), - ("ElectraConfig", "ElectraForMaskedLM"), - ("EncoderDecoderConfig", "EncoderDecoderModel"), - ("ReformerConfig", "ReformerModelWithLMHead"), - ("FunnelConfig", "FunnelForMaskedLM"), - ("MPNetConfig", "MPNetForMaskedLM"), - ("TapasConfig", "TapasForMaskedLM"), - ("DebertaConfig", "DebertaForMaskedLM"), - ("DebertaV2Config", "DebertaV2ForMaskedLM"), - ("IBertConfig", "IBertForMaskedLM"), - ] -) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 5e1d866c788..c8ef8a12381 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -17,7 +17,7 @@ {% if cookiecutter.is_encoder_decoder_model == "False" %} import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy as np import tensorflow 
as tf @@ -1484,7 +1484,7 @@ from ...file_utils import ( ) from ...modeling_tf_outputs import ( TFBaseModelOutput, - TFBaseModelOutputWithPast, + TFBaseModelOutputWithPastAndCrossAttentions, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput, ) @@ -2162,7 +2162,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer): ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py index 764a2586ef6..cfdb3484ced 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py @@ -172,17 +172,12 @@ # To replace in: "src/transformers/models/auto/configuration_auto.py" # Below: "# Add configs here" # Replace with: - ("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Config"), # End. # Below: "# Add archive maps here" # Replace with: - {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, -# End. 
- -# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig", -# Replace with: -from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP"), # End. # Below: "# Add full (and cased) model names here" @@ -193,75 +188,47 @@ from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowerca # To replace in: "src/transformers/models/auto/modeling_auto.py" if generating PyTorch -# Below: "from .configuration_auto import (" -# Replace with: - {{cookiecutter.camelcase_modelname}}Config, -# End. - -# Below: "# Add modeling imports here" -# Replace with: -{% if cookiecutter.is_encoder_decoder_model == "False" -%} -from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import ( - {{cookiecutter.camelcase_modelname}}ForMaskedLM, - {{cookiecutter.camelcase_modelname}}ForCausalLM, - {{cookiecutter.camelcase_modelname}}ForMultipleChoice, - {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, - {{cookiecutter.camelcase_modelname}}ForSequenceClassification, - {{cookiecutter.camelcase_modelname}}ForTokenClassification, - {{cookiecutter.camelcase_modelname}}Model, -) -{% else -%} -from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import ( - {{cookiecutter.camelcase_modelname}}ForConditionalGeneration, - {{cookiecutter.camelcase_modelname}}ForCausalLM, - {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, - {{cookiecutter.camelcase_modelname}}ForSequenceClassification, - {{cookiecutter.camelcase_modelname}}Model, -) -{% endif -%} -# End. 
- # Below: "# Base model mapping" # Replace with: - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Model), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Model"), # End. # Below: "# Model with LM heads mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"), {% else %} - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"), {% endif -%} # End. # Below: "# Model for Causal LM mapping" # Replace with: - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForCausalLM"), # End. # Below: "# Model for Masked LM mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"), {% else -%} {% endif -%} # End. # Below: "# Model for Sequence Classification mapping" # Replace with: - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForSequenceClassification), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification"), # End. # Below: "# Model for Question Answering mapping" # Replace with: - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"), # End. 
# Below: "# Model for Token Classification mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForTokenClassification), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForTokenClassification"), {% else -%} {% endif -%} # End. @@ -269,7 +236,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo # Below: "# Model for Multiple Choice mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMultipleChoice), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMultipleChoice"), {% else -%} {% endif -%} # End. @@ -278,54 +245,29 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} {% else %} - ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration), + ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"), {% endif -%} # End. # To replace in: "src/transformers/models/auto/modeling_tf_auto.py" if generating TensorFlow -# Below: "from .configuration_auto import (" -# Replace with: - {{cookiecutter.camelcase_modelname}}Config, -# End. 
- -# Below: "# Add modeling imports here" -# Replace with: -{% if cookiecutter.is_encoder_decoder_model == "False" -%} -from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import ( - TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, - TF{{cookiecutter.camelcase_modelname}}ForCausalLM, - TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, - TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, - TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification, - TF{{cookiecutter.camelcase_modelname}}ForTokenClassification, - TF{{cookiecutter.camelcase_modelname}}Model, -) -{% else -%} -from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import ( - TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, - TF{{cookiecutter.camelcase_modelname}}Model, -) -{% endif -%} -# End. - # Below: "# Base model mapping" # Replace with: - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}Model), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}Model"), # End. # Below: "# Model with LM heads mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"), {% else %} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"), {% endif -%} # End. 
# Below: "# Model for Causal LM mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForCausalLM), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForCausalLM"), {% else -%} {% endif -%} # End. @@ -333,7 +275,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase # Below: "# Model for Masked LM mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"), {% else -%} {% endif -%} # End. @@ -341,7 +283,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase # Below: "# Model for Sequence Classification mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification"), {% else -%} {% endif -%} # End. @@ -349,7 +291,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase # Below: "# Model for Question Answering mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"), {% else -%} {% endif -%} # End. 
@@ -357,7 +299,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase # Below: "# Model for Token Classification mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForTokenClassification), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification"), {% else -%} {% endif -%} # End. @@ -365,7 +307,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase # Below: "# Model for Multiple Choice mapping" # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice"), {% else -%} {% endif -%} # End. @@ -374,7 +316,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" -%} {% else %} - ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration), + ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"), {% endif -%} # End. 
diff --git a/tests/test_pipelines_translation.py b/tests/test_pipelines_translation.py index 222f7b4ed58..4456410d6f6 100644 --- a/tests/test_pipelines_translation.py +++ b/tests/test_pipelines_translation.py @@ -23,7 +23,8 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin if is_torch_available(): - from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration + from transformers.models.mbart import MBartForConditionalGeneration + from transformers.models.mbart50 import MBart50TokenizerFast class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): diff --git a/utils/check_repo.py b/utils/check_repo.py index 47cf8fd2175..f7425d36fe0 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -306,17 +306,17 @@ def get_all_auto_configured_models(): result = set() # To avoid duplicates we concatenate all model classes in a set. if is_torch_available(): for attr_name in dir(transformers.models.auto.modeling_auto): - if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"): + if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"): result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name))) if is_tf_available(): for attr_name in dir(transformers.models.auto.modeling_tf_auto): - if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"): + if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"): result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name))) if is_flax_available(): for attr_name in dir(transformers.models.auto.modeling_flax_auto): - if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"): + if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"): result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name))) - return [cls.__name__ for cls in result] + return [cls for cls in 
result] def ignore_unautoclassed(model_name): diff --git a/utils/check_table.py b/utils/check_table.py index 9151040fc93..4cb105b1193 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -87,12 +87,13 @@ def get_model_table_from_auto_modules(): transformers = spec.loader.load_module() # Dictionary model names to config. + config_maping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES model_name_to_config = { - name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items() - } - model_name_to_prefix = { - name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items() + name: config_maping_names[code] + for code, name in transformers.MODEL_NAMES_MAPPING.items() + if code in config_maping_names } + model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()} # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax. slow_tokenizers = collections.defaultdict(bool) diff --git a/utils/class_mapping_update.py b/utils/class_mapping_update.py deleted file mode 100644 index 71f02dcef44..00000000000 --- a/utils/class_mapping_update.py +++ /dev/null @@ -1,106 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# this script remaps classes to class strings so that it's quick to load such maps and not require -# loading all possible modeling files -# -# it can be extended to auto-generate other dicts that are needed at runtime - - -import os -import sys -from os.path import abspath, dirname, join - - -git_repo_path = abspath(join(dirname(dirname(__file__)), "src")) -sys.path.insert(1, git_repo_path) - -src = "src/transformers/models/auto/modeling_auto.py" -dst = "src/transformers/utils/modeling_auto_mapping.py" - - -if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst): - # speed things up by only running this script if the src is newer than dst - sys.exit(0) - -# only load if needed -from transformers.models.auto.modeling_auto import ( # noqa - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - MODEL_FOR_OBJECT_DETECTION_MAPPING, - MODEL_FOR_PRETRAINING_MAPPING, - MODEL_FOR_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - MODEL_MAPPING, - MODEL_WITH_LM_HEAD_MAPPING, -) - - -# Those constants don't have a name attribute, so we need to define it manually -mappings = { - "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING, - "MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING, - "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, - "MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING, - "MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING, - "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING, - 
"MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING, - "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, - "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - "MODEL_MAPPING": MODEL_MAPPING, - "MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING, -} - - -def get_name(value): - if isinstance(value, tuple): - return tuple(get_name(o) for o in value) - return value.__name__ - - -content = [ - "# THIS FILE HAS BEEN AUTOGENERATED. To update:", - "# 1. modify: models/auto/modeling_auto.py", - "# 2. run: python utils/class_mapping_update.py", - "from collections import OrderedDict", - "", -] - -for name, mapping in mappings.items(): - entries = "\n".join([f' ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()]) - - content += [ - "", - f"{name}_NAMES = OrderedDict(", - " [", - entries, - " ]", - ")", - "", - ] - -print(f"Updating {dst}") -with open(dst, "w", encoding="utf-8", newline="\n") as f: - f.write("\n".join(content))