diff --git a/.circleci/config.yml b/.circleci/config.yml index 0e8f394e9e9..6df8ad8e89f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -857,6 +857,7 @@ jobs: - run: black --check --preview examples tests src utils - run: isort --check-only examples tests src utils - run: python utils/custom_init_isort.py --check_only + - run: python utils/sort_auto_mappings.py --check_only - run: flake8 examples tests src utils - run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source diff --git a/Makefile b/Makefile index c9226bb0d8f..f0abc15de8e 100644 --- a/Makefile +++ b/Makefile @@ -48,6 +48,7 @@ quality: black --check --preview $(check_dirs) isort --check-only $(check_dirs) python utils/custom_init_isort.py --check_only + python utils/sort_auto_mappings.py --check_only flake8 $(check_dirs) doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source @@ -55,6 +56,7 @@ quality: extra_style_checks: python utils/custom_init_isort.py + python utils/sort_auto_mappings.py doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source # this target runs checks on all files and potentially modifies some of them diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 58249f6e164..e0b6674f69b 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow. 
| Swin | ❌ | ❌ | ✅ | ❌ | ❌ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | -| TAPEX | ✅ | ✅ | ✅ | ✅ | ✅ | | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ | | TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ | | UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index df39f1b97c5..4ae35a96aeb 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -74,7 +74,6 @@ Ready-made configurations include the following architectures: - RoBERTa - RoFormer - T5 -- TAPEX - ViT - XLM-RoBERTa - XLM-RoBERTa-XL diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 6c2297d9f84..b7d8f66c339 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -560,10 +560,17 @@ class _LazyAutoMapping(OrderedDict): if key in self._extra_content: return self._extra_content[key] model_type = self._reverse_config_mapping[key.__name__] - if model_type not in self._model_mapping: - raise KeyError(key) - model_name = self._model_mapping[model_type] - return self._load_attr_from_module(model_type, model_name) + if model_type in self._model_mapping: + model_name = self._model_mapping[model_type] + return self._load_attr_from_module(model_type, model_name) + + # Maybe there were several model types associated with this config. 
+ model_types = [k for k, v in self._config_mapping.items() if v == key.__name__] + for mtype in model_types: + if mtype in self._model_mapping: + model_name = self._model_mapping[mtype] + return self._load_attr_from_module(mtype, model_name) + raise KeyError(key) def _load_attr_from_module(self, model_type, attr): module_name = model_type_to_module_name(model_type) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index dd2a4b491ab..baa4f1aeb58 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -29,339 +29,338 @@ logger = logging.get_logger(__name__) CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here - ("yolos", "YolosConfig"), - ("tapex", "BartConfig"), - ("dpt", "DPTConfig"), - ("decision_transformer", "DecisionTransformerConfig"), - ("glpn", "GLPNConfig"), - ("maskformer", "MaskFormerConfig"), - ("decision_transformer", "DecisionTransformerConfig"), - ("poolformer", "PoolFormerConfig"), - ("convnext", "ConvNextConfig"), - ("van", "VanConfig"), - ("resnet", "ResNetConfig"), - ("regnet", "RegNetConfig"), - ("yoso", "YosoConfig"), - ("swin", "SwinConfig"), - ("vilt", "ViltConfig"), - ("vit_mae", "ViTMAEConfig"), - ("realm", "RealmConfig"), - ("nystromformer", "NystromformerConfig"), - ("xglm", "XGLMConfig"), - ("imagegpt", "ImageGPTConfig"), - ("qdqbert", "QDQBertConfig"), - ("vision-encoder-decoder", "VisionEncoderDecoderConfig"), - ("trocr", "TrOCRConfig"), - ("fnet", "FNetConfig"), - ("segformer", "SegformerConfig"), - ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"), - ("perceiver", "PerceiverConfig"), - ("gptj", "GPTJConfig"), - ("layoutlmv2", "LayoutLMv2Config"), - ("plbart", "PLBartConfig"), - ("beit", "BeitConfig"), - ("data2vec-vision", "Data2VecVisionConfig"), - ("rembert", "RemBertConfig"), - ("visual_bert", "VisualBertConfig"), - ("canine", "CanineConfig"), - ("roformer", "RoFormerConfig"), - 
("clip", "CLIPConfig"), - ("flava", "FlavaConfig"), - ("bigbird_pegasus", "BigBirdPegasusConfig"), - ("deit", "DeiTConfig"), - ("luke", "LukeConfig"), - ("detr", "DetrConfig"), - ("gpt_neo", "GPTNeoConfig"), - ("big_bird", "BigBirdConfig"), - ("speech_to_text_2", "Speech2Text2Config"), - ("speech_to_text", "Speech2TextConfig"), - ("vit", "ViTConfig"), - ("wav2vec2", "Wav2Vec2Config"), - ("m2m_100", "M2M100Config"), - ("convbert", "ConvBertConfig"), - ("led", "LEDConfig"), - ("blenderbot-small", "BlenderbotSmallConfig"), - ("retribert", "RetriBertConfig"), - ("ibert", "IBertConfig"), - ("mt5", "MT5Config"), - ("t5", "T5Config"), - ("mobilebert", "MobileBertConfig"), - ("distilbert", "DistilBertConfig"), ("albert", "AlbertConfig"), - ("bert-generation", "BertGenerationConfig"), - ("camembert", "CamembertConfig"), - ("xlm-roberta-xl", "XLMRobertaXLConfig"), - ("xlm-roberta", "XLMRobertaConfig"), - ("pegasus", "PegasusConfig"), - ("marian", "MarianConfig"), - ("mbart", "MBartConfig"), - ("megatron-bert", "MegatronBertConfig"), - ("mpnet", "MPNetConfig"), ("bart", "BartConfig"), - ("opt", "OPTConfig"), - ("blenderbot", "BlenderbotConfig"), - ("reformer", "ReformerConfig"), - ("longformer", "LongformerConfig"), - ("roberta", "RobertaConfig"), - ("deberta-v2", "DebertaV2Config"), - ("deberta", "DebertaConfig"), - ("flaubert", "FlaubertConfig"), - ("fsmt", "FSMTConfig"), - ("squeezebert", "SqueezeBertConfig"), - ("hubert", "HubertConfig"), + ("beit", "BeitConfig"), ("bert", "BertConfig"), - ("openai-gpt", "OpenAIGPTConfig"), - ("gpt2", "GPT2Config"), - ("transfo-xl", "TransfoXLConfig"), - ("xlnet", "XLNetConfig"), - ("xlm-prophetnet", "XLMProphetNetConfig"), - ("prophetnet", "ProphetNetConfig"), - ("xlm", "XLMConfig"), + ("bert-generation", "BertGenerationConfig"), + ("big_bird", "BigBirdConfig"), + ("bigbird_pegasus", "BigBirdPegasusConfig"), + ("blenderbot", "BlenderbotConfig"), + ("blenderbot-small", "BlenderbotSmallConfig"), + ("camembert", "CamembertConfig"), + 
("canine", "CanineConfig"), + ("clip", "CLIPConfig"), + ("convbert", "ConvBertConfig"), + ("convnext", "ConvNextConfig"), ("ctrl", "CTRLConfig"), - ("electra", "ElectraConfig"), - ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"), - ("encoder-decoder", "EncoderDecoderConfig"), - ("funnel", "FunnelConfig"), - ("lxmert", "LxmertConfig"), - ("dpr", "DPRConfig"), - ("layoutlm", "LayoutLMConfig"), - ("rag", "RagConfig"), - ("tapas", "TapasConfig"), - ("splinter", "SplinterConfig"), - ("sew-d", "SEWDConfig"), - ("sew", "SEWConfig"), - ("unispeech-sat", "UniSpeechSatConfig"), - ("unispeech", "UniSpeechConfig"), - ("wavlm", "WavLMConfig"), ("data2vec-audio", "Data2VecAudioConfig"), ("data2vec-text", "Data2VecTextConfig"), + ("data2vec-vision", "Data2VecVisionConfig"), + ("deberta", "DebertaConfig"), + ("deberta-v2", "DebertaV2Config"), + ("decision_transformer", "DecisionTransformerConfig"), + ("decision_transformer", "DecisionTransformerConfig"), + ("deit", "DeiTConfig"), + ("detr", "DetrConfig"), + ("distilbert", "DistilBertConfig"), + ("dpr", "DPRConfig"), + ("dpt", "DPTConfig"), + ("electra", "ElectraConfig"), + ("encoder-decoder", "EncoderDecoderConfig"), + ("flaubert", "FlaubertConfig"), + ("flava", "FlavaConfig"), + ("fnet", "FNetConfig"), + ("fsmt", "FSMTConfig"), + ("funnel", "FunnelConfig"), + ("glpn", "GLPNConfig"), + ("gpt2", "GPT2Config"), + ("gpt_neo", "GPTNeoConfig"), + ("gptj", "GPTJConfig"), + ("hubert", "HubertConfig"), + ("ibert", "IBertConfig"), + ("imagegpt", "ImageGPTConfig"), + ("layoutlm", "LayoutLMConfig"), + ("layoutlmv2", "LayoutLMv2Config"), + ("led", "LEDConfig"), + ("longformer", "LongformerConfig"), + ("luke", "LukeConfig"), + ("lxmert", "LxmertConfig"), + ("m2m_100", "M2M100Config"), + ("marian", "MarianConfig"), + ("maskformer", "MaskFormerConfig"), + ("mbart", "MBartConfig"), + ("megatron-bert", "MegatronBertConfig"), + ("mobilebert", "MobileBertConfig"), + ("mpnet", "MPNetConfig"), + ("mt5", "MT5Config"), + ("nystromformer", 
"NystromformerConfig"), + ("openai-gpt", "OpenAIGPTConfig"), + ("opt", "OPTConfig"), + ("pegasus", "PegasusConfig"), + ("perceiver", "PerceiverConfig"), + ("plbart", "PLBartConfig"), + ("poolformer", "PoolFormerConfig"), + ("prophetnet", "ProphetNetConfig"), + ("qdqbert", "QDQBertConfig"), + ("rag", "RagConfig"), + ("realm", "RealmConfig"), + ("reformer", "ReformerConfig"), + ("regnet", "RegNetConfig"), + ("rembert", "RemBertConfig"), + ("resnet", "ResNetConfig"), + ("retribert", "RetriBertConfig"), + ("roberta", "RobertaConfig"), + ("roformer", "RoFormerConfig"), + ("segformer", "SegformerConfig"), + ("sew", "SEWConfig"), + ("sew-d", "SEWDConfig"), + ("speech-encoder-decoder", "SpeechEncoderDecoderConfig"), + ("speech_to_text", "Speech2TextConfig"), + ("speech_to_text_2", "Speech2Text2Config"), + ("splinter", "SplinterConfig"), + ("squeezebert", "SqueezeBertConfig"), + ("swin", "SwinConfig"), + ("t5", "T5Config"), + ("tapas", "TapasConfig"), + ("transfo-xl", "TransfoXLConfig"), + ("trocr", "TrOCRConfig"), + ("unispeech", "UniSpeechConfig"), + ("unispeech-sat", "UniSpeechSatConfig"), + ("van", "VanConfig"), + ("vilt", "ViltConfig"), + ("vision-encoder-decoder", "VisionEncoderDecoderConfig"), + ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"), + ("visual_bert", "VisualBertConfig"), + ("vit", "ViTConfig"), + ("vit_mae", "ViTMAEConfig"), + ("wav2vec2", "Wav2Vec2Config"), + ("wavlm", "WavLMConfig"), + ("xglm", "XGLMConfig"), + ("xlm", "XLMConfig"), + ("xlm-prophetnet", "XLMProphetNetConfig"), + ("xlm-roberta", "XLMRobertaConfig"), + ("xlm-roberta-xl", "XLMRobertaXLConfig"), + ("xlnet", "XLNetConfig"), + ("yolos", "YolosConfig"), + ("yoso", "YosoConfig"), ] ) CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here) - ("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("poolformer", 
"POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("megatron-bert", 
"MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlm-roberta", 
"XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("fsmt", 
"FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), 
+ ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ] ) MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here - ("yolos", "YOLOS"), - ("tapex", "TAPEX"), - ("dpt", "DPT"), - ("decision_transformer", "Decision Transformer"), - ("glpn", "GLPN"), - ("maskformer", "MaskFormer"), - ("poolformer", "PoolFormer"), - ("convnext", "ConvNext"), - ("van", "VAN"), - ("resnet", 
"ResNet"), - ("regnet", "RegNet"), - ("yoso", "YOSO"), - ("swin", "Swin"), - ("vilt", "ViLT"), - ("vit_mae", "ViTMAE"), - ("realm", "Realm"), - ("nystromformer", "Nystromformer"), - ("xglm", "XGLM"), - ("imagegpt", "ImageGPT"), - ("qdqbert", "QDQBert"), - ("vision-encoder-decoder", "Vision Encoder decoder"), - ("trocr", "TrOCR"), - ("fnet", "FNet"), - ("segformer", "SegFormer"), - ("vision-text-dual-encoder", "VisionTextDualEncoder"), - ("perceiver", "Perceiver"), - ("gptj", "GPT-J"), - ("beit", "BEiT"), - ("plbart", "PLBart"), - ("rembert", "RemBERT"), - ("layoutlmv2", "LayoutLMv2"), - ("visual_bert", "VisualBert"), - ("canine", "Canine"), - ("roformer", "RoFormer"), - ("clip", "CLIP"), - ("flava", "Flava"), - ("bigbird_pegasus", "BigBirdPegasus"), - ("deit", "DeiT"), - ("luke", "LUKE"), - ("detr", "DETR"), - ("gpt_neo", "GPT Neo"), - ("big_bird", "BigBird"), - ("speech_to_text_2", "Speech2Text2"), - ("speech_to_text", "Speech2Text"), - ("vit", "ViT"), - ("wav2vec2", "Wav2Vec2"), - ("m2m_100", "M2M100"), - ("convbert", "ConvBERT"), - ("led", "LED"), - ("blenderbot-small", "BlenderbotSmall"), - ("retribert", "RetriBERT"), - ("ibert", "I-BERT"), - ("t5", "T5"), - ("mobilebert", "MobileBERT"), - ("distilbert", "DistilBERT"), ("albert", "ALBERT"), - ("bert-generation", "Bert Generation"), - ("camembert", "CamemBERT"), - ("xlm-roberta", "XLM-RoBERTa"), - ("xlm-roberta-xl", "XLM-RoBERTa-XL"), - ("pegasus", "Pegasus"), - ("blenderbot", "Blenderbot"), - ("marian", "Marian"), - ("mbart", "mBART"), - ("megatron-bert", "MegatronBert"), ("bart", "BART"), - ("opt", "OPT"), - ("reformer", "Reformer"), - ("longformer", "Longformer"), - ("roberta", "RoBERTa"), - ("flaubert", "FlauBERT"), - ("fsmt", "FairSeq Machine-Translation"), - ("squeezebert", "SqueezeBERT"), - ("bert", "BERT"), - ("openai-gpt", "OpenAI GPT"), - ("gpt2", "OpenAI GPT-2"), - ("transfo-xl", "Transformer-XL"), - ("xlnet", "XLNet"), - ("xlm", "XLM"), - ("ctrl", "CTRL"), - ("electra", "ELECTRA"), - 
("encoder-decoder", "Encoder decoder"), - ("speech-encoder-decoder", "Speech Encoder decoder"), - ("vision-encoder-decoder", "Vision Encoder decoder"), - ("funnel", "Funnel Transformer"), - ("lxmert", "LXMERT"), - ("deberta-v2", "DeBERTa-v2"), - ("deberta", "DeBERTa"), - ("layoutlm", "LayoutLM"), - ("dpr", "DPR"), - ("rag", "RAG"), - ("xlm-prophetnet", "XLMProphetNet"), - ("prophetnet", "ProphetNet"), - ("mt5", "mT5"), - ("mpnet", "MPNet"), - ("tapas", "TAPAS"), - ("hubert", "Hubert"), ("barthez", "BARThez"), - ("phobert", "PhoBERT"), ("bartpho", "BARTpho"), - ("cpm", "CPM"), - ("bertweet", "Bertweet"), + ("beit", "BEiT"), + ("bert", "BERT"), + ("bert-generation", "Bert Generation"), ("bert-japanese", "BertJapanese"), - ("byt5", "ByT5"), - ("mbart50", "mBART-50"), - ("splinter", "Splinter"), - ("sew-d", "SEW-D"), - ("sew", "SEW"), - ("unispeech-sat", "UniSpeechSat"), - ("unispeech", "UniSpeech"), - ("wavlm", "WavLM"), + ("bertweet", "Bertweet"), + ("big_bird", "BigBird"), + ("bigbird_pegasus", "BigBirdPegasus"), + ("blenderbot", "Blenderbot"), + ("blenderbot-small", "BlenderbotSmall"), ("bort", "BORT"), - ("dialogpt", "DialoGPT"), - ("xls_r", "XLS-R"), - ("t5v1.1", "T5v1.1"), - ("herbert", "HerBERT"), - ("wav2vec2_phoneme", "Wav2Vec2Phoneme"), - ("megatron_gpt2", "MegatronGPT2"), - ("xlsr_wav2vec2", "XLSR-Wav2Vec2"), - ("mluke", "mLUKE"), - ("layoutxlm", "LayoutXLM"), + ("byt5", "ByT5"), + ("camembert", "CamemBERT"), + ("canine", "Canine"), + ("clip", "CLIP"), + ("convbert", "ConvBERT"), + ("convnext", "ConvNext"), + ("cpm", "CPM"), + ("ctrl", "CTRL"), ("data2vec-audio", "Data2VecAudio"), ("data2vec-text", "Data2VecText"), ("data2vec-vision", "Data2VecVision"), + ("deberta", "DeBERTa"), + ("deberta-v2", "DeBERTa-v2"), + ("decision_transformer", "Decision Transformer"), + ("deit", "DeiT"), + ("detr", "DETR"), + ("dialogpt", "DialoGPT"), + ("distilbert", "DistilBERT"), ("dit", "DiT"), + ("dpr", "DPR"), + ("dpt", "DPT"), + ("electra", "ELECTRA"), + ("encoder-decoder", 
"Encoder decoder"), + ("flaubert", "FlauBERT"), + ("flava", "Flava"), + ("fnet", "FNet"), + ("fsmt", "FairSeq Machine-Translation"), + ("funnel", "Funnel Transformer"), + ("glpn", "GLPN"), + ("gpt2", "OpenAI GPT-2"), + ("gpt_neo", "GPT Neo"), + ("gptj", "GPT-J"), + ("herbert", "HerBERT"), + ("hubert", "Hubert"), + ("ibert", "I-BERT"), + ("imagegpt", "ImageGPT"), + ("layoutlm", "LayoutLM"), + ("layoutlmv2", "LayoutLMv2"), + ("layoutxlm", "LayoutXLM"), + ("led", "LED"), + ("longformer", "Longformer"), + ("luke", "LUKE"), + ("lxmert", "LXMERT"), + ("m2m_100", "M2M100"), + ("marian", "Marian"), + ("maskformer", "MaskFormer"), + ("mbart", "mBART"), + ("mbart50", "mBART-50"), + ("megatron-bert", "MegatronBert"), + ("megatron_gpt2", "MegatronGPT2"), + ("mluke", "mLUKE"), + ("mobilebert", "MobileBERT"), + ("mpnet", "MPNet"), + ("mt5", "mT5"), + ("nystromformer", "Nystromformer"), + ("openai-gpt", "OpenAI GPT"), + ("opt", "OPT"), + ("pegasus", "Pegasus"), + ("perceiver", "Perceiver"), + ("phobert", "PhoBERT"), + ("plbart", "PLBart"), + ("poolformer", "PoolFormer"), + ("prophetnet", "ProphetNet"), + ("qdqbert", "QDQBert"), + ("rag", "RAG"), + ("realm", "Realm"), + ("reformer", "Reformer"), + ("regnet", "RegNet"), + ("rembert", "RemBERT"), + ("resnet", "ResNet"), + ("retribert", "RetriBERT"), + ("roberta", "RoBERTa"), + ("roformer", "RoFormer"), + ("segformer", "SegFormer"), + ("sew", "SEW"), + ("sew-d", "SEW-D"), + ("speech-encoder-decoder", "Speech Encoder decoder"), + ("speech_to_text", "Speech2Text"), + ("speech_to_text_2", "Speech2Text2"), + ("splinter", "Splinter"), + ("squeezebert", "SqueezeBERT"), + ("swin", "Swin"), + ("t5", "T5"), + ("t5v1.1", "T5v1.1"), + ("tapas", "TAPAS"), + ("tapex", "TAPEX"), + ("transfo-xl", "Transformer-XL"), + ("trocr", "TrOCR"), + ("unispeech", "UniSpeech"), + ("unispeech-sat", "UniSpeechSat"), + ("van", "VAN"), + ("vilt", "ViLT"), + ("vision-encoder-decoder", "Vision Encoder decoder"), + ("vision-encoder-decoder", "Vision Encoder 
decoder"), + ("vision-text-dual-encoder", "VisionTextDualEncoder"), + ("visual_bert", "VisualBert"), + ("vit", "ViT"), + ("vit_mae", "ViTMAE"), + ("wav2vec2", "Wav2Vec2"), + ("wav2vec2_phoneme", "Wav2Vec2Phoneme"), + ("wavlm", "WavLM"), + ("xglm", "XGLM"), + ("xlm", "XLM"), + ("xlm-prophetnet", "XLMProphetNet"), + ("xlm-roberta", "XLM-RoBERTa"), + ("xlm-roberta-xl", "XLM-RoBERTa-XL"), + ("xlnet", "XLNet"), + ("xls_r", "XLS-R"), + ("xlsr_wav2vec2", "XLSR-Wav2Vec2"), + ("yolos", "YOLOS"), + ("yoso", "YOSO"), ] ) @@ -703,9 +702,10 @@ class AutoConfig: return config_class.from_dict(config_dict, **kwargs) else: # Fallback: use pattern matching on the string. - for pattern, config_class in CONFIG_MAPPING.items(): + # We go from longer names to shorter names to catch roberta before bert (for instance) + for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True): if pattern in str(pretrained_model_name_or_path): - return config_class.from_dict(config_dict, **kwargs) + return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs) raise ValueError( f"Unrecognized model in {pretrained_model_name_or_path}. 
" diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 233f4ff6c9f..64ae65fc790 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -38,30 +38,30 @@ logger = logging.get_logger(__name__) FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( [ ("beit", "BeitFeatureExtractor"), - ("detr", "DetrFeatureExtractor"), - ("deit", "DeiTFeatureExtractor"), - ("hubert", "Wav2Vec2FeatureExtractor"), - ("speech_to_text", "Speech2TextFeatureExtractor"), - ("vit", "ViTFeatureExtractor"), - ("wav2vec2", "Wav2Vec2FeatureExtractor"), - ("detr", "DetrFeatureExtractor"), - ("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("clip", "CLIPFeatureExtractor"), - ("flava", "FlavaFeatureExtractor"), - ("perceiver", "PerceiverFeatureExtractor"), - ("swin", "ViTFeatureExtractor"), - ("vit_mae", "ViTFeatureExtractor"), - ("segformer", "SegformerFeatureExtractor"), ("convnext", "ConvNextFeatureExtractor"), - ("van", "ConvNextFeatureExtractor"), - ("resnet", "ConvNextFeatureExtractor"), - ("regnet", "ConvNextFeatureExtractor"), - ("poolformer", "PoolFormerFeatureExtractor"), - ("maskformer", "MaskFormerFeatureExtractor"), ("data2vec-audio", "Wav2Vec2FeatureExtractor"), ("data2vec-vision", "BeitFeatureExtractor"), + ("deit", "DeiTFeatureExtractor"), + ("detr", "DetrFeatureExtractor"), + ("detr", "DetrFeatureExtractor"), ("dpt", "DPTFeatureExtractor"), + ("flava", "FlavaFeatureExtractor"), ("glpn", "GLPNFeatureExtractor"), + ("hubert", "Wav2Vec2FeatureExtractor"), + ("layoutlmv2", "LayoutLMv2FeatureExtractor"), + ("maskformer", "MaskFormerFeatureExtractor"), + ("perceiver", "PerceiverFeatureExtractor"), + ("poolformer", "PoolFormerFeatureExtractor"), + ("regnet", "ConvNextFeatureExtractor"), + ("resnet", "ConvNextFeatureExtractor"), + ("segformer", "SegformerFeatureExtractor"), + ("speech_to_text", "Speech2TextFeatureExtractor"), + ("swin", 
"ViTFeatureExtractor"), + ("van", "ConvNextFeatureExtractor"), + ("vit", "ViTFeatureExtractor"), + ("vit_mae", "ViTFeatureExtractor"), + ("wav2vec2", "Wav2Vec2FeatureExtractor"), ("yolos", "YolosFeatureExtractor"), ] ) @@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str): module_name = model_type_to_module_name(module_name) module = importlib.import_module(f".{module_name}", "transformers.models") - return getattr(module, class_name) - break + try: + return getattr(module, class_name) + except AttributeError: + continue for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items(): if getattr(extractor, "__name__", None) == class_name: diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 76809c53003..11bcee74db6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -28,257 +28,257 @@ logger = logging.get_logger(__name__) MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping - ("yolos", "YolosModel"), - ("dpt", "DPTModel"), - ("decision_transformer", "DecisionTransformerModel"), - ("glpn", "GLPNModel"), - ("maskformer", "MaskFormerModel"), - ("decision_transformer", "DecisionTransformerModel"), - ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"), - ("poolformer", "PoolFormerModel"), - ("convnext", "ConvNextModel"), - ("van", "VanModel"), - ("resnet", "ResNetModel"), - ("regnet", "RegNetModel"), - ("yoso", "YosoModel"), - ("swin", "SwinModel"), - ("vilt", "ViltModel"), - ("vit_mae", "ViTMAEModel"), - ("nystromformer", "NystromformerModel"), - ("xglm", "XGLMModel"), - ("imagegpt", "ImageGPTModel"), - ("qdqbert", "QDQBertModel"), - ("fnet", "FNetModel"), - ("segformer", "SegformerModel"), - ("vision-text-dual-encoder", "VisionTextDualEncoderModel"), - ("perceiver", "PerceiverModel"), - ("gptj", "GPTJModel"), - ("layoutlmv2", "LayoutLMv2Model"), - ("plbart", "PLBartModel"), - ("beit", "BeitModel"), - 
("data2vec-vision", "Data2VecVisionModel"), - ("rembert", "RemBertModel"), - ("visual_bert", "VisualBertModel"), - ("canine", "CanineModel"), - ("roformer", "RoFormerModel"), - ("clip", "CLIPModel"), - ("flava", "FlavaModel"), - ("bigbird_pegasus", "BigBirdPegasusModel"), - ("deit", "DeiTModel"), - ("luke", "LukeModel"), - ("detr", "DetrModel"), - ("gpt_neo", "GPTNeoModel"), - ("big_bird", "BigBirdModel"), - ("speech_to_text", "Speech2TextModel"), - ("vit", "ViTModel"), - ("wav2vec2", "Wav2Vec2Model"), - ("unispeech-sat", "UniSpeechSatModel"), - ("wavlm", "WavLMModel"), - ("unispeech", "UniSpeechModel"), - ("hubert", "HubertModel"), - ("m2m_100", "M2M100Model"), - ("convbert", "ConvBertModel"), - ("led", "LEDModel"), - ("blenderbot-small", "BlenderbotSmallModel"), - ("retribert", "RetriBertModel"), - ("mt5", "MT5Model"), - ("t5", "T5Model"), - ("pegasus", "PegasusModel"), - ("marian", "MarianModel"), - ("mbart", "MBartModel"), - ("blenderbot", "BlenderbotModel"), - ("distilbert", "DistilBertModel"), ("albert", "AlbertModel"), - ("camembert", "CamembertModel"), - ("xlm-roberta-xl", "XLMRobertaXLModel"), - ("xlm-roberta", "XLMRobertaModel"), ("bart", "BartModel"), - ("opt", "OPTModel"), - ("longformer", "LongformerModel"), - ("roberta", "RobertaModel"), - ("data2vec-text", "Data2VecTextModel"), - ("data2vec-audio", "Data2VecAudioModel"), - ("layoutlm", "LayoutLMModel"), - ("squeezebert", "SqueezeBertModel"), + ("beit", "BeitModel"), ("bert", "BertModel"), - ("openai-gpt", "OpenAIGPTModel"), - ("gpt2", "GPT2Model"), - ("megatron-bert", "MegatronBertModel"), - ("mobilebert", "MobileBertModel"), - ("transfo-xl", "TransfoXLModel"), - ("xlnet", "XLNetModel"), - ("flaubert", "FlaubertModel"), - ("fsmt", "FSMTModel"), - ("xlm", "XLMModel"), - ("ctrl", "CTRLModel"), - ("electra", "ElectraModel"), - ("reformer", "ReformerModel"), - ("funnel", ("FunnelModel", "FunnelBaseModel")), - ("lxmert", "LxmertModel"), ("bert-generation", "BertGenerationEncoder"), + ("big_bird", 
"BigBirdModel"), + ("bigbird_pegasus", "BigBirdPegasusModel"), + ("blenderbot", "BlenderbotModel"), + ("blenderbot-small", "BlenderbotSmallModel"), + ("camembert", "CamembertModel"), + ("canine", "CanineModel"), + ("clip", "CLIPModel"), + ("convbert", "ConvBertModel"), + ("convnext", "ConvNextModel"), + ("ctrl", "CTRLModel"), + ("data2vec-audio", "Data2VecAudioModel"), + ("data2vec-text", "Data2VecTextModel"), + ("data2vec-vision", "Data2VecVisionModel"), ("deberta", "DebertaModel"), ("deberta-v2", "DebertaV2Model"), + ("decision_transformer", "DecisionTransformerModel"), + ("decision_transformer", "DecisionTransformerModel"), + ("decision_transformer_gpt2", "DecisionTransformerGPT2Model"), + ("deit", "DeiTModel"), + ("detr", "DetrModel"), + ("distilbert", "DistilBertModel"), ("dpr", "DPRQuestionEncoder"), - ("xlm-prophetnet", "XLMProphetNetModel"), - ("prophetnet", "ProphetNetModel"), - ("mpnet", "MPNetModel"), - ("tapas", "TapasModel"), + ("dpt", "DPTModel"), + ("electra", "ElectraModel"), + ("flaubert", "FlaubertModel"), + ("flava", "FlavaModel"), + ("fnet", "FNetModel"), + ("fsmt", "FSMTModel"), + ("funnel", ("FunnelModel", "FunnelBaseModel")), + ("glpn", "GLPNModel"), + ("gpt2", "GPT2Model"), + ("gpt_neo", "GPTNeoModel"), + ("gptj", "GPTJModel"), + ("hubert", "HubertModel"), ("ibert", "IBertModel"), - ("splinter", "SplinterModel"), + ("imagegpt", "ImageGPTModel"), + ("layoutlm", "LayoutLMModel"), + ("layoutlmv2", "LayoutLMv2Model"), + ("led", "LEDModel"), + ("longformer", "LongformerModel"), + ("luke", "LukeModel"), + ("lxmert", "LxmertModel"), + ("m2m_100", "M2M100Model"), + ("marian", "MarianModel"), + ("maskformer", "MaskFormerModel"), + ("mbart", "MBartModel"), + ("megatron-bert", "MegatronBertModel"), + ("mobilebert", "MobileBertModel"), + ("mpnet", "MPNetModel"), + ("mt5", "MT5Model"), + ("nystromformer", "NystromformerModel"), + ("openai-gpt", "OpenAIGPTModel"), + ("opt", "OPTModel"), + ("pegasus", "PegasusModel"), + ("perceiver", "PerceiverModel"), + 
("plbart", "PLBartModel"), + ("poolformer", "PoolFormerModel"), + ("prophetnet", "ProphetNetModel"), + ("qdqbert", "QDQBertModel"), + ("reformer", "ReformerModel"), + ("regnet", "RegNetModel"), + ("rembert", "RemBertModel"), + ("resnet", "ResNetModel"), + ("retribert", "RetriBertModel"), + ("roberta", "RobertaModel"), + ("roformer", "RoFormerModel"), + ("segformer", "SegformerModel"), ("sew", "SEWModel"), ("sew-d", "SEWDModel"), + ("speech_to_text", "Speech2TextModel"), + ("splinter", "SplinterModel"), + ("squeezebert", "SqueezeBertModel"), + ("swin", "SwinModel"), + ("t5", "T5Model"), + ("tapas", "TapasModel"), + ("transfo-xl", "TransfoXLModel"), + ("unispeech", "UniSpeechModel"), + ("unispeech-sat", "UniSpeechSatModel"), + ("van", "VanModel"), + ("vilt", "ViltModel"), + ("vision-text-dual-encoder", "VisionTextDualEncoderModel"), + ("visual_bert", "VisualBertModel"), + ("vit", "ViTModel"), + ("vit_mae", "ViTMAEModel"), + ("wav2vec2", "Wav2Vec2Model"), + ("wavlm", "WavLMModel"), + ("xglm", "XGLMModel"), + ("xlm", "XLMModel"), + ("xlm-prophetnet", "XLMProphetNetModel"), + ("xlm-roberta", "XLMRobertaModel"), + ("xlm-roberta-xl", "XLMRobertaXLModel"), + ("xlnet", "XLNetModel"), + ("yolos", "YolosModel"), + ("yoso", "YosoModel"), ] ) MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping - ("flava", "FlavaForPreTraining"), - ("vit_mae", "ViTMAEForPreTraining"), - ("fnet", "FNetForPreTraining"), - ("visual_bert", "VisualBertForPreTraining"), - ("layoutlm", "LayoutLMForMaskedLM"), - ("retribert", "RetriBertModel"), - ("t5", "T5ForConditionalGeneration"), - ("distilbert", "DistilBertForMaskedLM"), ("albert", "AlbertForPreTraining"), - ("camembert", "CamembertForMaskedLM"), - ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), - ("xlm-roberta", "XLMRobertaForMaskedLM"), ("bart", "BartForConditionalGeneration"), - ("fsmt", "FSMTForConditionalGeneration"), - ("longformer", "LongformerForMaskedLM"), - ("roberta", "RobertaForMaskedLM"), - 
("data2vec-text", "Data2VecTextForMaskedLM"), - ("squeezebert", "SqueezeBertForMaskedLM"), ("bert", "BertForPreTraining"), ("big_bird", "BigBirdForPreTraining"), - ("openai-gpt", "OpenAIGPTLMHeadModel"), - ("gpt2", "GPT2LMHeadModel"), - ("megatron-bert", "MegatronBertForPreTraining"), - ("mobilebert", "MobileBertForPreTraining"), - ("transfo-xl", "TransfoXLLMHeadModel"), - ("xlnet", "XLNetLMHeadModel"), - ("flaubert", "FlaubertWithLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), + ("camembert", "CamembertForMaskedLM"), ("ctrl", "CTRLLMHeadModel"), - ("electra", "ElectraForPreTraining"), - ("lxmert", "LxmertForPreTraining"), - ("funnel", "FunnelForPreTraining"), - ("mpnet", "MPNetForMaskedLM"), - ("tapas", "TapasForMaskedLM"), - ("ibert", "IBertForMaskedLM"), + ("data2vec-text", "Data2VecTextForMaskedLM"), ("deberta", "DebertaForMaskedLM"), ("deberta-v2", "DebertaV2ForMaskedLM"), - ("wav2vec2", "Wav2Vec2ForPreTraining"), - ("unispeech-sat", "UniSpeechSatForPreTraining"), + ("distilbert", "DistilBertForMaskedLM"), + ("electra", "ElectraForPreTraining"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("flava", "FlavaForPreTraining"), + ("fnet", "FNetForPreTraining"), + ("fsmt", "FSMTForConditionalGeneration"), + ("funnel", "FunnelForPreTraining"), + ("gpt2", "GPT2LMHeadModel"), + ("ibert", "IBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("longformer", "LongformerForMaskedLM"), + ("lxmert", "LxmertForPreTraining"), + ("megatron-bert", "MegatronBertForPreTraining"), + ("mobilebert", "MobileBertForPreTraining"), + ("mpnet", "MPNetForMaskedLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("retribert", "RetriBertModel"), + ("roberta", "RobertaForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("t5", "T5ForConditionalGeneration"), + ("tapas", "TapasForMaskedLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), ("unispeech", "UniSpeechForPreTraining"), + ("unispeech-sat", "UniSpeechSatForPreTraining"), + ("visual_bert", "VisualBertForPreTraining"), + 
("vit_mae", "ViTMAEForPreTraining"), + ("wav2vec2", "Wav2Vec2ForPreTraining"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), + ("xlnet", "XLNetLMHeadModel"), ] ) MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping - ("yoso", "YosoForMaskedLM"), - ("nystromformer", "NystromformerForMaskedLM"), - ("plbart", "PLBartForConditionalGeneration"), - ("qdqbert", "QDQBertForMaskedLM"), - ("fnet", "FNetForMaskedLM"), - ("gptj", "GPTJForCausalLM"), - ("rembert", "RemBertForMaskedLM"), - ("roformer", "RoFormerForMaskedLM"), - ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), - ("gpt_neo", "GPTNeoForCausalLM"), - ("big_bird", "BigBirdForMaskedLM"), - ("speech_to_text", "Speech2TextForConditionalGeneration"), - ("wav2vec2", "Wav2Vec2ForMaskedLM"), - ("m2m_100", "M2M100ForConditionalGeneration"), - ("convbert", "ConvBertForMaskedLM"), - ("led", "LEDForConditionalGeneration"), - ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), - ("layoutlm", "LayoutLMForMaskedLM"), - ("t5", "T5ForConditionalGeneration"), - ("distilbert", "DistilBertForMaskedLM"), ("albert", "AlbertForMaskedLM"), - ("camembert", "CamembertForMaskedLM"), - ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), - ("xlm-roberta", "XLMRobertaForMaskedLM"), - ("marian", "MarianMTModel"), - ("fsmt", "FSMTForConditionalGeneration"), ("bart", "BartForConditionalGeneration"), - ("longformer", "LongformerForMaskedLM"), - ("roberta", "RobertaForMaskedLM"), - ("data2vec-text", "Data2VecTextForMaskedLM"), - ("squeezebert", "SqueezeBertForMaskedLM"), ("bert", "BertForMaskedLM"), - ("openai-gpt", "OpenAIGPTLMHeadModel"), - ("gpt2", "GPT2LMHeadModel"), - ("megatron-bert", "MegatronBertForCausalLM"), - ("mobilebert", "MobileBertForMaskedLM"), - ("transfo-xl", "TransfoXLLMHeadModel"), - ("xlnet", "XLNetLMHeadModel"), - ("flaubert", "FlaubertWithLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), + ("big_bird", 
"BigBirdForMaskedLM"), + ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), + ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), + ("camembert", "CamembertForMaskedLM"), + ("convbert", "ConvBertForMaskedLM"), ("ctrl", "CTRLLMHeadModel"), - ("electra", "ElectraForMaskedLM"), - ("encoder-decoder", "EncoderDecoderModel"), - ("reformer", "ReformerModelWithLMHead"), - ("funnel", "FunnelForMaskedLM"), - ("mpnet", "MPNetForMaskedLM"), - ("tapas", "TapasForMaskedLM"), + ("data2vec-text", "Data2VecTextForMaskedLM"), ("deberta", "DebertaForMaskedLM"), ("deberta-v2", "DebertaV2ForMaskedLM"), + ("distilbert", "DistilBertForMaskedLM"), + ("electra", "ElectraForMaskedLM"), + ("encoder-decoder", "EncoderDecoderModel"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("fnet", "FNetForMaskedLM"), + ("fsmt", "FSMTForConditionalGeneration"), + ("funnel", "FunnelForMaskedLM"), + ("gpt2", "GPT2LMHeadModel"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("gptj", "GPTJForCausalLM"), ("ibert", "IBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("led", "LEDForConditionalGeneration"), + ("longformer", "LongformerForMaskedLM"), + ("m2m_100", "M2M100ForConditionalGeneration"), + ("marian", "MarianMTModel"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), + ("nystromformer", "NystromformerForMaskedLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("plbart", "PLBartForConditionalGeneration"), + ("qdqbert", "QDQBertForMaskedLM"), + ("reformer", "ReformerModelWithLMHead"), + ("rembert", "RemBertForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("speech_to_text", "Speech2TextForConditionalGeneration"), + ("squeezebert", "SqueezeBertForMaskedLM"), + ("t5", "T5ForConditionalGeneration"), + ("tapas", "TapasForMaskedLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-roberta", 
"XLMRobertaForMaskedLM"), + ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), + ("xlnet", "XLNetLMHeadModel"), + ("yoso", "YosoForMaskedLM"), ] ) MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - ("xglm", "XGLMForCausalLM"), - ("plbart", "PLBartForCausalLM"), - ("qdqbert", "QDQBertLMHeadModel"), - ("trocr", "TrOCRForCausalLM"), - ("gptj", "GPTJForCausalLM"), - ("rembert", "RemBertForCausalLM"), - ("roformer", "RoFormerForCausalLM"), - ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), - ("gpt_neo", "GPTNeoForCausalLM"), - ("big_bird", "BigBirdForCausalLM"), - ("camembert", "CamembertForCausalLM"), - ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), - ("xlm-roberta", "XLMRobertaForCausalLM"), - ("roberta", "RobertaForCausalLM"), - ("bert", "BertLMHeadModel"), - ("openai-gpt", "OpenAIGPTLMHeadModel"), - ("gpt2", "GPT2LMHeadModel"), - ("transfo-xl", "TransfoXLLMHeadModel"), - ("xlnet", "XLNetLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), - ("electra", "ElectraForCausalLM"), - ("ctrl", "CTRLLMHeadModel"), - ("reformer", "ReformerModelWithLMHead"), - ("bert-generation", "BertGenerationDecoder"), - ("xlm-prophetnet", "XLMProphetNetForCausalLM"), - ("prophetnet", "ProphetNetForCausalLM"), ("bart", "BartForCausalLM"), - ("opt", "OPTForCausalLM"), - ("mbart", "MBartForCausalLM"), - ("pegasus", "PegasusForCausalLM"), - ("marian", "MarianForCausalLM"), + ("bert", "BertLMHeadModel"), + ("bert-generation", "BertGenerationDecoder"), + ("big_bird", "BigBirdForCausalLM"), + ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), ("blenderbot", "BlenderbotForCausalLM"), ("blenderbot-small", "BlenderbotSmallForCausalLM"), - ("megatron-bert", "MegatronBertForCausalLM"), - ("speech_to_text_2", "Speech2Text2ForCausalLM"), + ("camembert", "CamembertForCausalLM"), + ("ctrl", "CTRLLMHeadModel"), ("data2vec-text", "Data2VecTextForCausalLM"), + ("electra", "ElectraForCausalLM"), + ("gpt2", "GPT2LMHeadModel"), + ("gpt_neo", "GPTNeoForCausalLM"), + ("gptj", 
"GPTJForCausalLM"), + ("marian", "MarianForCausalLM"), + ("mbart", "MBartForCausalLM"), + ("megatron-bert", "MegatronBertForCausalLM"), + ("openai-gpt", "OpenAIGPTLMHeadModel"), + ("opt", "OPTForCausalLM"), + ("pegasus", "PegasusForCausalLM"), + ("plbart", "PLBartForCausalLM"), + ("prophetnet", "ProphetNetForCausalLM"), + ("qdqbert", "QDQBertLMHeadModel"), + ("reformer", "ReformerModelWithLMHead"), + ("rembert", "RemBertForCausalLM"), + ("roberta", "RobertaForCausalLM"), + ("roformer", "RoFormerForCausalLM"), + ("speech_to_text_2", "Speech2Text2ForCausalLM"), + ("transfo-xl", "TransfoXLLMHeadModel"), + ("trocr", "TrOCRForCausalLM"), + ("xglm", "XGLMForCausalLM"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-prophetnet", "XLMProphetNetForCausalLM"), + ("xlm-roberta", "XLMRobertaForCausalLM"), + ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"), + ("xlnet", "XLNetLMHeadModel"), ] ) MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( [ - ("vit", "ViTForMaskedImageModeling"), ("deit", "DeiTForMaskedImageModeling"), ("swin", "SwinForMaskedImageModeling"), + ("vit", "ViTForMaskedImageModeling"), ] ) @@ -293,11 +293,10 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict( MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image Classification mapping - ("vit", "ViTForImageClassification"), - ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), ("beit", "BeitForImageClassification"), + ("convnext", "ConvNextForImageClassification"), ("data2vec-vision", "Data2VecVisionForImageClassification"), - ("segformer", "SegformerForImageClassification"), + ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), ("imagegpt", "ImageGPTForImageClassification"), ( "perceiver", @@ -307,12 +306,13 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( "PerceiverForImageClassificationConvProcessing", ), ), - ("swin", "SwinForImageClassification"), - ("convnext", 
"ConvNextForImageClassification"), - ("van", "VanForImageClassification"), - ("resnet", "ResNetForImageClassification"), - ("regnet", "RegNetForImageClassification"), ("poolformer", "PoolFormerForImageClassification"), + ("regnet", "RegNetForImageClassification"), + ("resnet", "ResNetForImageClassification"), + ("segformer", "SegformerForImageClassification"), + ("swin", "SwinForImageClassification"), + ("van", "VanForImageClassification"), + ("vit", "ViTForImageClassification"), ] ) @@ -329,8 +329,8 @@ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict( # Model for Semantic Segmentation mapping ("beit", "BeitForSemanticSegmentation"), ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"), - ("segformer", "SegformerForSemanticSegmentation"), ("dpt", "DPTForSemanticSegmentation"), + ("segformer", "SegformerForSemanticSegmentation"), ] ) @@ -350,72 +350,71 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - ("yoso", "YosoForMaskedLM"), + ("albert", "AlbertForMaskedLM"), + ("bart", "BartForConditionalGeneration"), + ("bert", "BertForMaskedLM"), + ("big_bird", "BigBirdForMaskedLM"), + ("camembert", "CamembertForMaskedLM"), + ("convbert", "ConvBertForMaskedLM"), + ("data2vec-text", "Data2VecTextForMaskedLM"), + ("deberta", "DebertaForMaskedLM"), + ("deberta-v2", "DebertaV2ForMaskedLM"), + ("distilbert", "DistilBertForMaskedLM"), + ("electra", "ElectraForMaskedLM"), + ("flaubert", "FlaubertWithLMHeadModel"), + ("fnet", "FNetForMaskedLM"), + ("funnel", "FunnelForMaskedLM"), + ("ibert", "IBertForMaskedLM"), + ("layoutlm", "LayoutLMForMaskedLM"), + ("longformer", "LongformerForMaskedLM"), + ("mbart", "MBartForConditionalGeneration"), + ("megatron-bert", "MegatronBertForMaskedLM"), + ("mobilebert", "MobileBertForMaskedLM"), + ("mpnet", "MPNetForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"), ("perceiver", "PerceiverForMaskedLM"), ("qdqbert", "QDQBertForMaskedLM"), 
- ("fnet", "FNetForMaskedLM"), - ("rembert", "RemBertForMaskedLM"), - ("roformer", "RoFormerForMaskedLM"), - ("big_bird", "BigBirdForMaskedLM"), - ("wav2vec2", "Wav2Vec2ForMaskedLM"), - ("convbert", "ConvBertForMaskedLM"), - ("layoutlm", "LayoutLMForMaskedLM"), - ("distilbert", "DistilBertForMaskedLM"), - ("albert", "AlbertForMaskedLM"), - ("bart", "BartForConditionalGeneration"), - ("mbart", "MBartForConditionalGeneration"), - ("camembert", "CamembertForMaskedLM"), - ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), - ("xlm-roberta", "XLMRobertaForMaskedLM"), - ("longformer", "LongformerForMaskedLM"), - ("roberta", "RobertaForMaskedLM"), - ("data2vec-text", "Data2VecTextForMaskedLM"), - ("squeezebert", "SqueezeBertForMaskedLM"), - ("bert", "BertForMaskedLM"), - ("megatron-bert", "MegatronBertForMaskedLM"), - ("mobilebert", "MobileBertForMaskedLM"), - ("flaubert", "FlaubertWithLMHeadModel"), - ("xlm", "XLMWithLMHeadModel"), - ("electra", "ElectraForMaskedLM"), ("reformer", "ReformerForMaskedLM"), - ("funnel", "FunnelForMaskedLM"), - ("mpnet", "MPNetForMaskedLM"), + ("rembert", "RemBertForMaskedLM"), + ("roberta", "RobertaForMaskedLM"), + ("roformer", "RoFormerForMaskedLM"), + ("squeezebert", "SqueezeBertForMaskedLM"), ("tapas", "TapasForMaskedLM"), - ("deberta", "DebertaForMaskedLM"), - ("deberta-v2", "DebertaV2ForMaskedLM"), - ("ibert", "IBertForMaskedLM"), + ("wav2vec2", "Wav2Vec2ForMaskedLM"), + ("xlm", "XLMWithLMHeadModel"), + ("xlm-roberta", "XLMRobertaForMaskedLM"), + ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"), + ("yoso", "YosoForMaskedLM"), ] ) MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( [ # Model for Object Detection mapping - ("yolos", "YolosForObjectDetection"), ("detr", "DetrForObjectDetection"), + ("yolos", "YolosForObjectDetection"), ] ) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - ("tapex", "BartForConditionalGeneration"), - ("plbart", "PLBartForConditionalGeneration"), + ("bart", 
"BartForConditionalGeneration"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), - ("m2m_100", "M2M100ForConditionalGeneration"), - ("led", "LEDForConditionalGeneration"), + ("blenderbot", "BlenderbotForConditionalGeneration"), ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), - ("mt5", "MT5ForConditionalGeneration"), - ("t5", "T5ForConditionalGeneration"), - ("pegasus", "PegasusForConditionalGeneration"), + ("encoder-decoder", "EncoderDecoderModel"), + ("fsmt", "FSMTForConditionalGeneration"), + ("led", "LEDForConditionalGeneration"), + ("m2m_100", "M2M100ForConditionalGeneration"), ("marian", "MarianMTModel"), ("mbart", "MBartForConditionalGeneration"), - ("blenderbot", "BlenderbotForConditionalGeneration"), - ("bart", "BartForConditionalGeneration"), - ("fsmt", "FSMTForConditionalGeneration"), - ("encoder-decoder", "EncoderDecoderModel"), - ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), + ("mt5", "MT5ForConditionalGeneration"), + ("pegasus", "PegasusForConditionalGeneration"), + ("plbart", "PLBartForConditionalGeneration"), ("prophetnet", "ProphetNetForConditionalGeneration"), + ("t5", "T5ForConditionalGeneration"), + ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] ) @@ -429,98 +428,97 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - ("tapex", "BartForSequenceClassification"), - ("yoso", "YosoForSequenceClassification"), - ("nystromformer", "NystromformerForSequenceClassification"), - ("plbart", "PLBartForSequenceClassification"), - ("perceiver", "PerceiverForSequenceClassification"), - ("qdqbert", "QDQBertForSequenceClassification"), - ("fnet", "FNetForSequenceClassification"), - ("gptj", "GPTJForSequenceClassification"), - ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), - ("rembert", "RemBertForSequenceClassification"), - ("canine", "CanineForSequenceClassification"), - 
("roformer", "RoFormerForSequenceClassification"), - ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), - ("big_bird", "BigBirdForSequenceClassification"), - ("convbert", "ConvBertForSequenceClassification"), - ("led", "LEDForSequenceClassification"), - ("distilbert", "DistilBertForSequenceClassification"), ("albert", "AlbertForSequenceClassification"), - ("camembert", "CamembertForSequenceClassification"), - ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"), - ("xlm-roberta", "XLMRobertaForSequenceClassification"), - ("mbart", "MBartForSequenceClassification"), ("bart", "BartForSequenceClassification"), - ("longformer", "LongformerForSequenceClassification"), - ("roberta", "RobertaForSequenceClassification"), - ("data2vec-text", "Data2VecTextForSequenceClassification"), - ("squeezebert", "SqueezeBertForSequenceClassification"), - ("layoutlm", "LayoutLMForSequenceClassification"), ("bert", "BertForSequenceClassification"), - ("xlnet", "XLNetForSequenceClassification"), - ("megatron-bert", "MegatronBertForSequenceClassification"), - ("mobilebert", "MobileBertForSequenceClassification"), - ("flaubert", "FlaubertForSequenceClassification"), - ("xlm", "XLMForSequenceClassification"), - ("electra", "ElectraForSequenceClassification"), - ("funnel", "FunnelForSequenceClassification"), + ("big_bird", "BigBirdForSequenceClassification"), + ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"), + ("camembert", "CamembertForSequenceClassification"), + ("canine", "CanineForSequenceClassification"), + ("convbert", "ConvBertForSequenceClassification"), + ("ctrl", "CTRLForSequenceClassification"), + ("data2vec-text", "Data2VecTextForSequenceClassification"), ("deberta", "DebertaForSequenceClassification"), ("deberta-v2", "DebertaV2ForSequenceClassification"), + ("distilbert", "DistilBertForSequenceClassification"), + ("electra", "ElectraForSequenceClassification"), + ("flaubert", "FlaubertForSequenceClassification"), + ("fnet", 
"FNetForSequenceClassification"), + ("funnel", "FunnelForSequenceClassification"), ("gpt2", "GPT2ForSequenceClassification"), ("gpt_neo", "GPTNeoForSequenceClassification"), - ("openai-gpt", "OpenAIGPTForSequenceClassification"), - ("reformer", "ReformerForSequenceClassification"), - ("ctrl", "CTRLForSequenceClassification"), - ("transfo-xl", "TransfoXLForSequenceClassification"), - ("mpnet", "MPNetForSequenceClassification"), - ("tapas", "TapasForSequenceClassification"), + ("gptj", "GPTJForSequenceClassification"), ("ibert", "IBertForSequenceClassification"), + ("layoutlm", "LayoutLMForSequenceClassification"), + ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), + ("led", "LEDForSequenceClassification"), + ("longformer", "LongformerForSequenceClassification"), + ("mbart", "MBartForSequenceClassification"), + ("megatron-bert", "MegatronBertForSequenceClassification"), + ("mobilebert", "MobileBertForSequenceClassification"), + ("mpnet", "MPNetForSequenceClassification"), + ("nystromformer", "NystromformerForSequenceClassification"), + ("openai-gpt", "OpenAIGPTForSequenceClassification"), + ("perceiver", "PerceiverForSequenceClassification"), + ("plbart", "PLBartForSequenceClassification"), + ("qdqbert", "QDQBertForSequenceClassification"), + ("reformer", "ReformerForSequenceClassification"), + ("rembert", "RemBertForSequenceClassification"), + ("roberta", "RobertaForSequenceClassification"), + ("roformer", "RoFormerForSequenceClassification"), + ("squeezebert", "SqueezeBertForSequenceClassification"), + ("tapas", "TapasForSequenceClassification"), + ("transfo-xl", "TransfoXLForSequenceClassification"), + ("xlm", "XLMForSequenceClassification"), + ("xlm-roberta", "XLMRobertaForSequenceClassification"), + ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"), + ("xlnet", "XLNetForSequenceClassification"), + ("yoso", "YosoForSequenceClassification"), ] ) MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - 
("yoso", "YosoForQuestionAnswering"), - ("nystromformer", "NystromformerForQuestionAnswering"), - ("qdqbert", "QDQBertForQuestionAnswering"), - ("fnet", "FNetForQuestionAnswering"), - ("gptj", "GPTJForQuestionAnswering"), - ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), - ("rembert", "RemBertForQuestionAnswering"), - ("canine", "CanineForQuestionAnswering"), - ("roformer", "RoFormerForQuestionAnswering"), - ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), - ("big_bird", "BigBirdForQuestionAnswering"), - ("convbert", "ConvBertForQuestionAnswering"), - ("led", "LEDForQuestionAnswering"), - ("distilbert", "DistilBertForQuestionAnswering"), ("albert", "AlbertForQuestionAnswering"), - ("camembert", "CamembertForQuestionAnswering"), ("bart", "BartForQuestionAnswering"), - ("mbart", "MBartForQuestionAnswering"), - ("longformer", "LongformerForQuestionAnswering"), - ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"), - ("xlm-roberta", "XLMRobertaForQuestionAnswering"), - ("roberta", "RobertaForQuestionAnswering"), - ("squeezebert", "SqueezeBertForQuestionAnswering"), ("bert", "BertForQuestionAnswering"), - ("xlnet", "XLNetForQuestionAnsweringSimple"), - ("flaubert", "FlaubertForQuestionAnsweringSimple"), - ("megatron-bert", "MegatronBertForQuestionAnswering"), - ("mobilebert", "MobileBertForQuestionAnswering"), - ("xlm", "XLMForQuestionAnsweringSimple"), - ("electra", "ElectraForQuestionAnswering"), - ("reformer", "ReformerForQuestionAnswering"), - ("funnel", "FunnelForQuestionAnswering"), - ("lxmert", "LxmertForQuestionAnswering"), - ("mpnet", "MPNetForQuestionAnswering"), + ("big_bird", "BigBirdForQuestionAnswering"), + ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"), + ("camembert", "CamembertForQuestionAnswering"), + ("canine", "CanineForQuestionAnswering"), + ("convbert", "ConvBertForQuestionAnswering"), + ("data2vec-text", "Data2VecTextForQuestionAnswering"), ("deberta", "DebertaForQuestionAnswering"), ("deberta-v2", 
"DebertaV2ForQuestionAnswering"), + ("distilbert", "DistilBertForQuestionAnswering"), + ("electra", "ElectraForQuestionAnswering"), + ("flaubert", "FlaubertForQuestionAnsweringSimple"), + ("fnet", "FNetForQuestionAnswering"), + ("funnel", "FunnelForQuestionAnswering"), + ("gptj", "GPTJForQuestionAnswering"), ("ibert", "IBertForQuestionAnswering"), + ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), + ("led", "LEDForQuestionAnswering"), + ("longformer", "LongformerForQuestionAnswering"), + ("lxmert", "LxmertForQuestionAnswering"), + ("mbart", "MBartForQuestionAnswering"), + ("megatron-bert", "MegatronBertForQuestionAnswering"), + ("mobilebert", "MobileBertForQuestionAnswering"), + ("mpnet", "MPNetForQuestionAnswering"), + ("nystromformer", "NystromformerForQuestionAnswering"), + ("qdqbert", "QDQBertForQuestionAnswering"), + ("reformer", "ReformerForQuestionAnswering"), + ("rembert", "RemBertForQuestionAnswering"), + ("roberta", "RobertaForQuestionAnswering"), + ("roformer", "RoFormerForQuestionAnswering"), ("splinter", "SplinterForQuestionAnswering"), - ("data2vec-text", "Data2VecTextForQuestionAnswering"), + ("squeezebert", "SqueezeBertForQuestionAnswering"), + ("xlm", "XLMForQuestionAnsweringSimple"), + ("xlm-roberta", "XLMRobertaForQuestionAnswering"), + ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"), + ("xlnet", "XLNetForQuestionAnsweringSimple"), + ("yoso", "YosoForQuestionAnswering"), ] ) @@ -534,132 +532,132 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - ("yoso", "YosoForTokenClassification"), - ("nystromformer", "NystromformerForTokenClassification"), - ("qdqbert", "QDQBertForTokenClassification"), - ("fnet", "FNetForTokenClassification"), - ("layoutlmv2", "LayoutLMv2ForTokenClassification"), - ("rembert", "RemBertForTokenClassification"), - ("canine", "CanineForTokenClassification"), - ("roformer", 
"RoFormerForTokenClassification"), - ("big_bird", "BigBirdForTokenClassification"), - ("convbert", "ConvBertForTokenClassification"), - ("layoutlm", "LayoutLMForTokenClassification"), - ("distilbert", "DistilBertForTokenClassification"), - ("camembert", "CamembertForTokenClassification"), - ("flaubert", "FlaubertForTokenClassification"), - ("xlm", "XLMForTokenClassification"), - ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"), - ("xlm-roberta", "XLMRobertaForTokenClassification"), - ("longformer", "LongformerForTokenClassification"), - ("roberta", "RobertaForTokenClassification"), - ("squeezebert", "SqueezeBertForTokenClassification"), - ("bert", "BertForTokenClassification"), - ("megatron-bert", "MegatronBertForTokenClassification"), - ("mobilebert", "MobileBertForTokenClassification"), - ("xlnet", "XLNetForTokenClassification"), ("albert", "AlbertForTokenClassification"), - ("electra", "ElectraForTokenClassification"), - ("funnel", "FunnelForTokenClassification"), - ("mpnet", "MPNetForTokenClassification"), + ("bert", "BertForTokenClassification"), + ("big_bird", "BigBirdForTokenClassification"), + ("camembert", "CamembertForTokenClassification"), + ("canine", "CanineForTokenClassification"), + ("convbert", "ConvBertForTokenClassification"), + ("data2vec-text", "Data2VecTextForTokenClassification"), ("deberta", "DebertaForTokenClassification"), ("deberta-v2", "DebertaV2ForTokenClassification"), + ("distilbert", "DistilBertForTokenClassification"), + ("electra", "ElectraForTokenClassification"), + ("flaubert", "FlaubertForTokenClassification"), + ("fnet", "FNetForTokenClassification"), + ("funnel", "FunnelForTokenClassification"), ("gpt2", "GPT2ForTokenClassification"), ("ibert", "IBertForTokenClassification"), - ("data2vec-text", "Data2VecTextForTokenClassification"), + ("layoutlm", "LayoutLMForTokenClassification"), + ("layoutlmv2", "LayoutLMv2ForTokenClassification"), + ("longformer", "LongformerForTokenClassification"), + ("megatron-bert", 
"MegatronBertForTokenClassification"), + ("mobilebert", "MobileBertForTokenClassification"), + ("mpnet", "MPNetForTokenClassification"), + ("nystromformer", "NystromformerForTokenClassification"), + ("qdqbert", "QDQBertForTokenClassification"), + ("rembert", "RemBertForTokenClassification"), + ("roberta", "RobertaForTokenClassification"), + ("roformer", "RoFormerForTokenClassification"), + ("squeezebert", "SqueezeBertForTokenClassification"), + ("xlm", "XLMForTokenClassification"), + ("xlm-roberta", "XLMRobertaForTokenClassification"), + ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"), + ("xlnet", "XLNetForTokenClassification"), + ("yoso", "YosoForTokenClassification"), ] ) MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - ("yoso", "YosoForMultipleChoice"), - ("nystromformer", "NystromformerForMultipleChoice"), - ("qdqbert", "QDQBertForMultipleChoice"), - ("fnet", "FNetForMultipleChoice"), - ("rembert", "RemBertForMultipleChoice"), - ("canine", "CanineForMultipleChoice"), - ("roformer", "RoFormerForMultipleChoice"), - ("big_bird", "BigBirdForMultipleChoice"), - ("convbert", "ConvBertForMultipleChoice"), - ("camembert", "CamembertForMultipleChoice"), - ("electra", "ElectraForMultipleChoice"), - ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"), - ("xlm-roberta", "XLMRobertaForMultipleChoice"), - ("longformer", "LongformerForMultipleChoice"), - ("roberta", "RobertaForMultipleChoice"), - ("data2vec-text", "Data2VecTextForMultipleChoice"), - ("squeezebert", "SqueezeBertForMultipleChoice"), + ("albert", "AlbertForMultipleChoice"), ("bert", "BertForMultipleChoice"), + ("big_bird", "BigBirdForMultipleChoice"), + ("camembert", "CamembertForMultipleChoice"), + ("canine", "CanineForMultipleChoice"), + ("convbert", "ConvBertForMultipleChoice"), + ("data2vec-text", "Data2VecTextForMultipleChoice"), + ("deberta-v2", "DebertaV2ForMultipleChoice"), ("distilbert", "DistilBertForMultipleChoice"), + ("electra", 
"ElectraForMultipleChoice"), + ("flaubert", "FlaubertForMultipleChoice"), + ("fnet", "FNetForMultipleChoice"), + ("funnel", "FunnelForMultipleChoice"), + ("ibert", "IBertForMultipleChoice"), + ("longformer", "LongformerForMultipleChoice"), ("megatron-bert", "MegatronBertForMultipleChoice"), ("mobilebert", "MobileBertForMultipleChoice"), - ("xlnet", "XLNetForMultipleChoice"), - ("albert", "AlbertForMultipleChoice"), - ("xlm", "XLMForMultipleChoice"), - ("flaubert", "FlaubertForMultipleChoice"), - ("funnel", "FunnelForMultipleChoice"), ("mpnet", "MPNetForMultipleChoice"), - ("ibert", "IBertForMultipleChoice"), - ("deberta-v2", "DebertaV2ForMultipleChoice"), + ("nystromformer", "NystromformerForMultipleChoice"), + ("qdqbert", "QDQBertForMultipleChoice"), + ("rembert", "RemBertForMultipleChoice"), + ("roberta", "RobertaForMultipleChoice"), + ("roformer", "RoFormerForMultipleChoice"), + ("squeezebert", "SqueezeBertForMultipleChoice"), + ("xlm", "XLMForMultipleChoice"), + ("xlm-roberta", "XLMRobertaForMultipleChoice"), + ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"), + ("xlnet", "XLNetForMultipleChoice"), + ("yoso", "YosoForMultipleChoice"), ] ) MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict( [ - ("qdqbert", "QDQBertForNextSentencePrediction"), ("bert", "BertForNextSentencePrediction"), ("fnet", "FNetForNextSentencePrediction"), ("megatron-bert", "MegatronBertForNextSentencePrediction"), ("mobilebert", "MobileBertForNextSentencePrediction"), + ("qdqbert", "QDQBertForNextSentencePrediction"), ] ) MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Audio Classification mapping - ("wav2vec2", "Wav2Vec2ForSequenceClassification"), - ("unispeech-sat", "UniSpeechSatForSequenceClassification"), - ("unispeech", "UniSpeechForSequenceClassification"), + ("data2vec-audio", "Data2VecAudioForSequenceClassification"), ("hubert", "HubertForSequenceClassification"), ("sew", "SEWForSequenceClassification"), ("sew-d", 
"SEWDForSequenceClassification"), + ("unispeech", "UniSpeechForSequenceClassification"), + ("unispeech-sat", "UniSpeechSatForSequenceClassification"), + ("wav2vec2", "Wav2Vec2ForSequenceClassification"), ("wavlm", "WavLMForSequenceClassification"), - ("data2vec-audio", "Data2VecAudioForSequenceClassification"), ] ) MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( [ # Model for Connectionist temporal classification (CTC) mapping - ("wav2vec2", "Wav2Vec2ForCTC"), - ("unispeech-sat", "UniSpeechSatForCTC"), - ("unispeech", "UniSpeechForCTC"), + ("data2vec-audio", "Data2VecAudioForCTC"), ("hubert", "HubertForCTC"), ("sew", "SEWForCTC"), ("sew-d", "SEWDForCTC"), + ("unispeech", "UniSpeechForCTC"), + ("unispeech-sat", "UniSpeechSatForCTC"), + ("wav2vec2", "Wav2Vec2ForCTC"), ("wavlm", "WavLMForCTC"), - ("data2vec-audio", "Data2VecAudioForCTC"), ] ) MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Audio Classification mapping - ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"), - ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"), - ("wavlm", "WavLMForAudioFrameClassification"), ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"), + ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"), + ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"), + ("wavlm", "WavLMForAudioFrameClassification"), ] ) MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( [ # Model for Audio Classification mapping - ("wav2vec2", "Wav2Vec2ForXVector"), - ("unispeech-sat", "UniSpeechSatForXVector"), - ("wavlm", "WavLMForXVector"), ("data2vec-audio", "Data2VecAudioForXVector"), + ("unispeech-sat", "UniSpeechSatForXVector"), + ("wav2vec2", "Wav2Vec2ForXVector"), + ("wavlm", "WavLMForXVector"), ] ) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 78803178bec..230ffcbce34 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ 
b/src/transformers/models/auto/modeling_flax_auto.py @@ -28,31 +28,31 @@ logger = logging.get_logger(__name__) FLAX_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping - ("xglm", "FlaxXGLMModel"), - ("blenderbot-small", "FlaxBlenderbotSmallModel"), - ("pegasus", "FlaxPegasusModel"), - ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), - ("distilbert", "FlaxDistilBertModel"), ("albert", "FlaxAlbertModel"), - ("roberta", "FlaxRobertaModel"), - ("xlm-roberta", "FlaxXLMRobertaModel"), - ("bert", "FlaxBertModel"), - ("beit", "FlaxBeitModel"), - ("big_bird", "FlaxBigBirdModel"), ("bart", "FlaxBartModel"), + ("beit", "FlaxBeitModel"), + ("bert", "FlaxBertModel"), + ("big_bird", "FlaxBigBirdModel"), + ("blenderbot", "FlaxBlenderbotModel"), + ("blenderbot-small", "FlaxBlenderbotSmallModel"), + ("clip", "FlaxCLIPModel"), + ("distilbert", "FlaxDistilBertModel"), + ("electra", "FlaxElectraModel"), ("gpt2", "FlaxGPT2Model"), ("gpt_neo", "FlaxGPTNeoModel"), ("gptj", "FlaxGPTJModel"), - ("electra", "FlaxElectraModel"), - ("clip", "FlaxCLIPModel"), - ("vit", "FlaxViTModel"), - ("mbart", "FlaxMBartModel"), - ("t5", "FlaxT5Model"), - ("mt5", "FlaxMT5Model"), - ("wav2vec2", "FlaxWav2Vec2Model"), ("marian", "FlaxMarianModel"), - ("blenderbot", "FlaxBlenderbotModel"), + ("mbart", "FlaxMBartModel"), + ("mt5", "FlaxMT5Model"), + ("pegasus", "FlaxPegasusModel"), + ("roberta", "FlaxRobertaModel"), ("roformer", "FlaxRoFormerModel"), + ("t5", "FlaxT5Model"), + ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), + ("vit", "FlaxViTModel"), + ("wav2vec2", "FlaxWav2Vec2Model"), + ("xglm", "FlaxXGLMModel"), + ("xlm-roberta", "FlaxXLMRobertaModel"), ] ) @@ -60,56 +60,56 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping ("albert", "FlaxAlbertForPreTraining"), - ("roberta", "FlaxRobertaForMaskedLM"), - ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), + ("bart", "FlaxBartForConditionalGeneration"), ("bert", 
"FlaxBertForPreTraining"), ("big_bird", "FlaxBigBirdForPreTraining"), - ("bart", "FlaxBartForConditionalGeneration"), ("electra", "FlaxElectraForPreTraining"), ("mbart", "FlaxMBartForConditionalGeneration"), - ("t5", "FlaxT5ForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), - ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), + ("roberta", "FlaxRobertaForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"), + ("t5", "FlaxT5ForConditionalGeneration"), + ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), + ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), ] ) FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - ("distilbert", "FlaxDistilBertForMaskedLM"), ("albert", "FlaxAlbertForMaskedLM"), - ("roberta", "FlaxRobertaForMaskedLM"), - ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), + ("bart", "FlaxBartForConditionalGeneration"), ("bert", "FlaxBertForMaskedLM"), ("big_bird", "FlaxBigBirdForMaskedLM"), - ("bart", "FlaxBartForConditionalGeneration"), + ("distilbert", "FlaxDistilBertForMaskedLM"), ("electra", "FlaxElectraForMaskedLM"), ("mbart", "FlaxMBartForConditionalGeneration"), + ("roberta", "FlaxRobertaForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"), + ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), ] ) FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"), - ("pegasus", "FlaxPegasusForConditionalGeneration"), ("bart", "FlaxBartForConditionalGeneration"), - ("mbart", "FlaxMBartForConditionalGeneration"), - ("t5", "FlaxT5ForConditionalGeneration"), - ("mt5", "FlaxMT5ForConditionalGeneration"), - ("marian", "FlaxMarianMTModel"), - ("encoder-decoder", "FlaxEncoderDecoderModel"), ("blenderbot", "FlaxBlenderbotForConditionalGeneration"), + ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"), + ("encoder-decoder", "FlaxEncoderDecoderModel"), + ("marian", "FlaxMarianMTModel"), + ("mbart", 
"FlaxMBartForConditionalGeneration"), + ("mt5", "FlaxMT5ForConditionalGeneration"), + ("pegasus", "FlaxPegasusForConditionalGeneration"), + ("t5", "FlaxT5ForConditionalGeneration"), ] ) FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image-classsification - ("vit", "FlaxViTForImageClassification"), ("beit", "FlaxBeitForImageClassification"), + ("vit", "FlaxViTForImageClassification"), ] ) @@ -122,75 +122,75 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("bart", "FlaxBartForCausalLM"), + ("bert", "FlaxBertForCausalLM"), + ("big_bird", "FlaxBigBirdForCausalLM"), + ("electra", "FlaxElectraForCausalLM"), ("gpt2", "FlaxGPT2LMHeadModel"), ("gpt_neo", "FlaxGPTNeoForCausalLM"), ("gptj", "FlaxGPTJForCausalLM"), - ("xglm", "FlaxXGLMForCausalLM"), - ("bart", "FlaxBartForCausalLM"), - ("bert", "FlaxBertForCausalLM"), ("roberta", "FlaxRobertaForCausalLM"), - ("big_bird", "FlaxBigBirdForCausalLM"), - ("electra", "FlaxElectraForCausalLM"), + ("xglm", "FlaxXGLMForCausalLM"), ] ) FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - ("distilbert", "FlaxDistilBertForSequenceClassification"), ("albert", "FlaxAlbertForSequenceClassification"), - ("roberta", "FlaxRobertaForSequenceClassification"), - ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"), + ("bart", "FlaxBartForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"), ("big_bird", "FlaxBigBirdForSequenceClassification"), - ("bart", "FlaxBartForSequenceClassification"), + ("distilbert", "FlaxDistilBertForSequenceClassification"), ("electra", "FlaxElectraForSequenceClassification"), ("mbart", "FlaxMBartForSequenceClassification"), + ("roberta", "FlaxRobertaForSequenceClassification"), ("roformer", "FlaxRoFormerForSequenceClassification"), + ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"), ] ) 
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - ("distilbert", "FlaxDistilBertForQuestionAnswering"), ("albert", "FlaxAlbertForQuestionAnswering"), - ("roberta", "FlaxRobertaForQuestionAnswering"), - ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"), + ("bart", "FlaxBartForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"), ("big_bird", "FlaxBigBirdForQuestionAnswering"), - ("bart", "FlaxBartForQuestionAnswering"), + ("distilbert", "FlaxDistilBertForQuestionAnswering"), ("electra", "FlaxElectraForQuestionAnswering"), ("mbart", "FlaxMBartForQuestionAnswering"), + ("roberta", "FlaxRobertaForQuestionAnswering"), ("roformer", "FlaxRoFormerForQuestionAnswering"), + ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"), ] ) FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - ("distilbert", "FlaxDistilBertForTokenClassification"), ("albert", "FlaxAlbertForTokenClassification"), - ("roberta", "FlaxRobertaForTokenClassification"), - ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"), ("bert", "FlaxBertForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"), + ("distilbert", "FlaxDistilBertForTokenClassification"), ("electra", "FlaxElectraForTokenClassification"), + ("roberta", "FlaxRobertaForTokenClassification"), ("roformer", "FlaxRoFormerForTokenClassification"), + ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"), ] ) FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - ("distilbert", "FlaxDistilBertForMultipleChoice"), ("albert", "FlaxAlbertForMultipleChoice"), - ("roberta", "FlaxRobertaForMultipleChoice"), - ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"), ("bert", "FlaxBertForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"), + ("distilbert", "FlaxDistilBertForMultipleChoice"), ("electra", "FlaxElectraForMultipleChoice"), + ("roberta", 
"FlaxRobertaForMultipleChoice"), ("roformer", "FlaxRoFormerForMultipleChoice"), + ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"), ] ) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 456d1426dc2..3fe804cb4c5 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -29,142 +29,142 @@ logger = logging.get_logger(__name__) TF_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping - ("speech_to_text", "TFSpeech2TextModel"), - ("clip", "TFCLIPModel"), - ("deberta-v2", "TFDebertaV2Model"), - ("deberta", "TFDebertaModel"), - ("rembert", "TFRemBertModel"), - ("roformer", "TFRoFormerModel"), - ("convbert", "TFConvBertModel"), - ("convnext", "TFConvNextModel"), - ("data2vec-vision", "TFData2VecVisionModel"), - ("led", "TFLEDModel"), - ("lxmert", "TFLxmertModel"), - ("mt5", "TFMT5Model"), - ("t5", "TFT5Model"), - ("distilbert", "TFDistilBertModel"), ("albert", "TFAlbertModel"), ("bart", "TFBartModel"), - ("camembert", "TFCamembertModel"), - ("xlm-roberta", "TFXLMRobertaModel"), - ("longformer", "TFLongformerModel"), - ("roberta", "TFRobertaModel"), - ("layoutlm", "TFLayoutLMModel"), ("bert", "TFBertModel"), - ("openai-gpt", "TFOpenAIGPTModel"), - ("gpt2", "TFGPT2Model"), - ("gptj", "TFGPTJModel"), - ("mobilebert", "TFMobileBertModel"), - ("transfo-xl", "TFTransfoXLModel"), - ("xlnet", "TFXLNetModel"), - ("flaubert", "TFFlaubertModel"), - ("xlm", "TFXLMModel"), - ("ctrl", "TFCTRLModel"), - ("electra", "TFElectraModel"), - ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")), - ("dpr", "TFDPRQuestionEncoder"), - ("mpnet", "TFMPNetModel"), - ("tapas", "TFTapasModel"), - ("mbart", "TFMBartModel"), - ("marian", "TFMarianModel"), - ("pegasus", "TFPegasusModel"), ("blenderbot", "TFBlenderbotModel"), ("blenderbot-small", "TFBlenderbotSmallModel"), + ("camembert", "TFCamembertModel"), + ("clip", "TFCLIPModel"), + ("convbert", "TFConvBertModel"), + 
("convnext", "TFConvNextModel"), + ("ctrl", "TFCTRLModel"), + ("data2vec-vision", "TFData2VecVisionModel"), + ("deberta", "TFDebertaModel"), + ("deberta-v2", "TFDebertaV2Model"), + ("distilbert", "TFDistilBertModel"), + ("dpr", "TFDPRQuestionEncoder"), + ("electra", "TFElectraModel"), + ("flaubert", "TFFlaubertModel"), + ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")), + ("gpt2", "TFGPT2Model"), + ("gptj", "TFGPTJModel"), + ("hubert", "TFHubertModel"), + ("layoutlm", "TFLayoutLMModel"), + ("led", "TFLEDModel"), + ("longformer", "TFLongformerModel"), + ("lxmert", "TFLxmertModel"), + ("marian", "TFMarianModel"), + ("mbart", "TFMBartModel"), + ("mobilebert", "TFMobileBertModel"), + ("mpnet", "TFMPNetModel"), + ("mt5", "TFMT5Model"), + ("openai-gpt", "TFOpenAIGPTModel"), + ("pegasus", "TFPegasusModel"), + ("rembert", "TFRemBertModel"), + ("roberta", "TFRobertaModel"), + ("roformer", "TFRoFormerModel"), + ("speech_to_text", "TFSpeech2TextModel"), + ("t5", "TFT5Model"), + ("tapas", "TFTapasModel"), + ("transfo-xl", "TFTransfoXLModel"), ("vit", "TFViTModel"), ("vit_mae", "TFViTMAEModel"), ("wav2vec2", "TFWav2Vec2Model"), - ("hubert", "TFHubertModel"), + ("xlm", "TFXLMModel"), + ("xlm-roberta", "TFXLMRobertaModel"), + ("xlnet", "TFXLNetModel"), ] ) TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping - ("lxmert", "TFLxmertForPreTraining"), - ("t5", "TFT5ForConditionalGeneration"), - ("distilbert", "TFDistilBertForMaskedLM"), ("albert", "TFAlbertForPreTraining"), ("bart", "TFBartForConditionalGeneration"), - ("camembert", "TFCamembertForMaskedLM"), - ("xlm-roberta", "TFXLMRobertaForMaskedLM"), - ("roberta", "TFRobertaForMaskedLM"), - ("layoutlm", "TFLayoutLMForMaskedLM"), ("bert", "TFBertForPreTraining"), - ("openai-gpt", "TFOpenAIGPTLMHeadModel"), - ("gpt2", "TFGPT2LMHeadModel"), - ("mobilebert", "TFMobileBertForPreTraining"), - ("transfo-xl", "TFTransfoXLLMHeadModel"), - ("xlnet", "TFXLNetLMHeadModel"), - ("flaubert", 
"TFFlaubertWithLMHeadModel"), - ("xlm", "TFXLMWithLMHeadModel"), + ("camembert", "TFCamembertForMaskedLM"), ("ctrl", "TFCTRLLMHeadModel"), + ("distilbert", "TFDistilBertForMaskedLM"), ("electra", "TFElectraForPreTraining"), - ("tapas", "TFTapasForMaskedLM"), + ("flaubert", "TFFlaubertWithLMHeadModel"), ("funnel", "TFFunnelForPreTraining"), + ("gpt2", "TFGPT2LMHeadModel"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("lxmert", "TFLxmertForPreTraining"), + ("mobilebert", "TFMobileBertForPreTraining"), ("mpnet", "TFMPNetForMaskedLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("roberta", "TFRobertaForMaskedLM"), + ("t5", "TFT5ForConditionalGeneration"), + ("tapas", "TFTapasForMaskedLM"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), ("vit_mae", "TFViTMAEForPreTraining"), + ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("xlnet", "TFXLNetLMHeadModel"), ] ) TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping - ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), - ("rembert", "TFRemBertForMaskedLM"), - ("roformer", "TFRoFormerForMaskedLM"), - ("convbert", "TFConvBertForMaskedLM"), - ("led", "TFLEDForConditionalGeneration"), - ("t5", "TFT5ForConditionalGeneration"), - ("distilbert", "TFDistilBertForMaskedLM"), ("albert", "TFAlbertForMaskedLM"), - ("marian", "TFMarianMTModel"), ("bart", "TFBartForConditionalGeneration"), - ("camembert", "TFCamembertForMaskedLM"), - ("xlm-roberta", "TFXLMRobertaForMaskedLM"), - ("longformer", "TFLongformerForMaskedLM"), - ("roberta", "TFRobertaForMaskedLM"), - ("layoutlm", "TFLayoutLMForMaskedLM"), ("bert", "TFBertForMaskedLM"), - ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("camembert", "TFCamembertForMaskedLM"), + ("convbert", "TFConvBertForMaskedLM"), + ("ctrl", "TFCTRLLMHeadModel"), + ("distilbert", "TFDistilBertForMaskedLM"), + ("electra", "TFElectraForMaskedLM"), + ("flaubert", "TFFlaubertWithLMHeadModel"), + ("funnel", "TFFunnelForMaskedLM"), ("gpt2", 
"TFGPT2LMHeadModel"), ("gptj", "TFGPTJForCausalLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("led", "TFLEDForConditionalGeneration"), + ("longformer", "TFLongformerForMaskedLM"), + ("marian", "TFMarianMTModel"), ("mobilebert", "TFMobileBertForMaskedLM"), - ("transfo-xl", "TFTransfoXLLMHeadModel"), - ("xlnet", "TFXLNetLMHeadModel"), - ("flaubert", "TFFlaubertWithLMHeadModel"), - ("xlm", "TFXLMWithLMHeadModel"), - ("ctrl", "TFCTRLLMHeadModel"), - ("electra", "TFElectraForMaskedLM"), - ("tapas", "TFTapasForMaskedLM"), - ("funnel", "TFFunnelForMaskedLM"), ("mpnet", "TFMPNetForMaskedLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("rembert", "TFRemBertForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("roformer", "TFRoFormerForMaskedLM"), + ("speech_to_text", "TFSpeech2TextForConditionalGeneration"), + ("t5", "TFT5ForConditionalGeneration"), + ("tapas", "TFTapasForMaskedLM"), + ("transfo-xl", "TFTransfoXLLMHeadModel"), + ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), + ("xlnet", "TFXLNetLMHeadModel"), ] ) TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping - ("camembert", "TFCamembertForCausalLM"), - ("rembert", "TFRemBertForCausalLM"), - ("roformer", "TFRoFormerForCausalLM"), - ("roberta", "TFRobertaForCausalLM"), ("bert", "TFBertLMHeadModel"), - ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("camembert", "TFCamembertForCausalLM"), + ("ctrl", "TFCTRLLMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"), ("gptj", "TFGPTJForCausalLM"), + ("openai-gpt", "TFOpenAIGPTLMHeadModel"), + ("rembert", "TFRemBertForCausalLM"), + ("roberta", "TFRobertaForCausalLM"), + ("roformer", "TFRoFormerForCausalLM"), ("transfo-xl", "TFTransfoXLLMHeadModel"), - ("xlnet", "TFXLNetLMHeadModel"), ("xlm", "TFXLMWithLMHeadModel"), - ("ctrl", "TFCTRLLMHeadModel"), + ("xlnet", "TFXLNetLMHeadModel"), ] ) TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Image-classsification - ("vit", 
"TFViTForImageClassification"), ("convnext", "TFConvNextForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"), + ("vit", "TFViTForImageClassification"), ] ) @@ -177,42 +177,42 @@ TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping - ("deberta-v2", "TFDebertaV2ForMaskedLM"), - ("deberta", "TFDebertaForMaskedLM"), - ("rembert", "TFRemBertForMaskedLM"), - ("roformer", "TFRoFormerForMaskedLM"), - ("convbert", "TFConvBertForMaskedLM"), - ("distilbert", "TFDistilBertForMaskedLM"), ("albert", "TFAlbertForMaskedLM"), - ("camembert", "TFCamembertForMaskedLM"), - ("xlm-roberta", "TFXLMRobertaForMaskedLM"), - ("longformer", "TFLongformerForMaskedLM"), - ("roberta", "TFRobertaForMaskedLM"), - ("layoutlm", "TFLayoutLMForMaskedLM"), ("bert", "TFBertForMaskedLM"), - ("mobilebert", "TFMobileBertForMaskedLM"), - ("flaubert", "TFFlaubertWithLMHeadModel"), - ("xlm", "TFXLMWithLMHeadModel"), + ("camembert", "TFCamembertForMaskedLM"), + ("convbert", "TFConvBertForMaskedLM"), + ("deberta", "TFDebertaForMaskedLM"), + ("deberta-v2", "TFDebertaV2ForMaskedLM"), + ("distilbert", "TFDistilBertForMaskedLM"), ("electra", "TFElectraForMaskedLM"), - ("tapas", "TFTapasForMaskedLM"), + ("flaubert", "TFFlaubertWithLMHeadModel"), ("funnel", "TFFunnelForMaskedLM"), + ("layoutlm", "TFLayoutLMForMaskedLM"), + ("longformer", "TFLongformerForMaskedLM"), + ("mobilebert", "TFMobileBertForMaskedLM"), ("mpnet", "TFMPNetForMaskedLM"), + ("rembert", "TFRemBertForMaskedLM"), + ("roberta", "TFRobertaForMaskedLM"), + ("roformer", "TFRoFormerForMaskedLM"), + ("tapas", "TFTapasForMaskedLM"), + ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForMaskedLM"), ] ) TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping - ("led", "TFLEDForConditionalGeneration"), - ("mt5", "TFMT5ForConditionalGeneration"), - ("t5", 
"TFT5ForConditionalGeneration"), - ("marian", "TFMarianMTModel"), - ("mbart", "TFMBartForConditionalGeneration"), - ("pegasus", "TFPegasusForConditionalGeneration"), + ("bart", "TFBartForConditionalGeneration"), ("blenderbot", "TFBlenderbotForConditionalGeneration"), ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"), - ("bart", "TFBartForConditionalGeneration"), ("encoder-decoder", "TFEncoderDecoderModel"), + ("led", "TFLEDForConditionalGeneration"), + ("marian", "TFMarianMTModel"), + ("mbart", "TFMBartForConditionalGeneration"), + ("mt5", "TFMT5ForConditionalGeneration"), + ("pegasus", "TFPegasusForConditionalGeneration"), + ("t5", "TFT5ForConditionalGeneration"), ] ) @@ -225,58 +225,58 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping - ("deberta-v2", "TFDebertaV2ForSequenceClassification"), - ("deberta", "TFDebertaForSequenceClassification"), - ("rembert", "TFRemBertForSequenceClassification"), - ("roformer", "TFRoFormerForSequenceClassification"), - ("convbert", "TFConvBertForSequenceClassification"), - ("distilbert", "TFDistilBertForSequenceClassification"), ("albert", "TFAlbertForSequenceClassification"), - ("camembert", "TFCamembertForSequenceClassification"), - ("xlm-roberta", "TFXLMRobertaForSequenceClassification"), - ("longformer", "TFLongformerForSequenceClassification"), - ("roberta", "TFRobertaForSequenceClassification"), - ("layoutlm", "TFLayoutLMForSequenceClassification"), ("bert", "TFBertForSequenceClassification"), - ("xlnet", "TFXLNetForSequenceClassification"), - ("mobilebert", "TFMobileBertForSequenceClassification"), - ("flaubert", "TFFlaubertForSequenceClassification"), - ("xlm", "TFXLMForSequenceClassification"), + ("camembert", "TFCamembertForSequenceClassification"), + ("convbert", "TFConvBertForSequenceClassification"), + ("ctrl", "TFCTRLForSequenceClassification"), + ("deberta", 
"TFDebertaForSequenceClassification"), + ("deberta-v2", "TFDebertaV2ForSequenceClassification"), + ("distilbert", "TFDistilBertForSequenceClassification"), ("electra", "TFElectraForSequenceClassification"), - ("tapas", "TFTapasForSequenceClassification"), + ("flaubert", "TFFlaubertForSequenceClassification"), ("funnel", "TFFunnelForSequenceClassification"), ("gpt2", "TFGPT2ForSequenceClassification"), ("gptj", "TFGPTJForSequenceClassification"), + ("layoutlm", "TFLayoutLMForSequenceClassification"), + ("longformer", "TFLongformerForSequenceClassification"), + ("mobilebert", "TFMobileBertForSequenceClassification"), ("mpnet", "TFMPNetForSequenceClassification"), ("openai-gpt", "TFOpenAIGPTForSequenceClassification"), + ("rembert", "TFRemBertForSequenceClassification"), + ("roberta", "TFRobertaForSequenceClassification"), + ("roformer", "TFRoFormerForSequenceClassification"), + ("tapas", "TFTapasForSequenceClassification"), ("transfo-xl", "TFTransfoXLForSequenceClassification"), - ("ctrl", "TFCTRLForSequenceClassification"), + ("xlm", "TFXLMForSequenceClassification"), + ("xlm-roberta", "TFXLMRobertaForSequenceClassification"), + ("xlnet", "TFXLNetForSequenceClassification"), ] ) TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping - ("deberta-v2", "TFDebertaV2ForQuestionAnswering"), - ("deberta", "TFDebertaForQuestionAnswering"), - ("rembert", "TFRemBertForQuestionAnswering"), - ("roformer", "TFRoFormerForQuestionAnswering"), - ("convbert", "TFConvBertForQuestionAnswering"), - ("distilbert", "TFDistilBertForQuestionAnswering"), ("albert", "TFAlbertForQuestionAnswering"), - ("camembert", "TFCamembertForQuestionAnswering"), - ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"), - ("longformer", "TFLongformerForQuestionAnswering"), - ("roberta", "TFRobertaForQuestionAnswering"), ("bert", "TFBertForQuestionAnswering"), - ("xlnet", "TFXLNetForQuestionAnsweringSimple"), - ("mobilebert", "TFMobileBertForQuestionAnswering"), 
- ("flaubert", "TFFlaubertForQuestionAnsweringSimple"), - ("xlm", "TFXLMForQuestionAnsweringSimple"), + ("camembert", "TFCamembertForQuestionAnswering"), + ("convbert", "TFConvBertForQuestionAnswering"), + ("deberta", "TFDebertaForQuestionAnswering"), + ("deberta-v2", "TFDebertaV2ForQuestionAnswering"), + ("distilbert", "TFDistilBertForQuestionAnswering"), ("electra", "TFElectraForQuestionAnswering"), + ("flaubert", "TFFlaubertForQuestionAnsweringSimple"), ("funnel", "TFFunnelForQuestionAnswering"), ("gptj", "TFGPTJForQuestionAnswering"), + ("longformer", "TFLongformerForQuestionAnswering"), + ("mobilebert", "TFMobileBertForQuestionAnswering"), ("mpnet", "TFMPNetForQuestionAnswering"), + ("rembert", "TFRemBertForQuestionAnswering"), + ("roberta", "TFRobertaForQuestionAnswering"), + ("roformer", "TFRoFormerForQuestionAnswering"), + ("xlm", "TFXLMForQuestionAnsweringSimple"), + ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"), + ("xlnet", "TFXLNetForQuestionAnsweringSimple"), ] ) @@ -291,49 +291,49 @@ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping - ("deberta-v2", "TFDebertaV2ForTokenClassification"), - ("deberta", "TFDebertaForTokenClassification"), - ("rembert", "TFRemBertForTokenClassification"), - ("roformer", "TFRoFormerForTokenClassification"), - ("convbert", "TFConvBertForTokenClassification"), - ("distilbert", "TFDistilBertForTokenClassification"), ("albert", "TFAlbertForTokenClassification"), + ("bert", "TFBertForTokenClassification"), ("camembert", "TFCamembertForTokenClassification"), + ("convbert", "TFConvBertForTokenClassification"), + ("deberta", "TFDebertaForTokenClassification"), + ("deberta-v2", "TFDebertaV2ForTokenClassification"), + ("distilbert", "TFDistilBertForTokenClassification"), + ("electra", "TFElectraForTokenClassification"), ("flaubert", "TFFlaubertForTokenClassification"), + ("funnel", 
"TFFunnelForTokenClassification"), + ("layoutlm", "TFLayoutLMForTokenClassification"), + ("longformer", "TFLongformerForTokenClassification"), + ("mobilebert", "TFMobileBertForTokenClassification"), + ("mpnet", "TFMPNetForTokenClassification"), + ("rembert", "TFRemBertForTokenClassification"), + ("roberta", "TFRobertaForTokenClassification"), + ("roformer", "TFRoFormerForTokenClassification"), ("xlm", "TFXLMForTokenClassification"), ("xlm-roberta", "TFXLMRobertaForTokenClassification"), - ("longformer", "TFLongformerForTokenClassification"), - ("roberta", "TFRobertaForTokenClassification"), - ("layoutlm", "TFLayoutLMForTokenClassification"), - ("bert", "TFBertForTokenClassification"), - ("mobilebert", "TFMobileBertForTokenClassification"), ("xlnet", "TFXLNetForTokenClassification"), - ("electra", "TFElectraForTokenClassification"), - ("funnel", "TFFunnelForTokenClassification"), - ("mpnet", "TFMPNetForTokenClassification"), ] ) TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping - ("rembert", "TFRemBertForMultipleChoice"), - ("roformer", "TFRoFormerForMultipleChoice"), - ("convbert", "TFConvBertForMultipleChoice"), + ("albert", "TFAlbertForMultipleChoice"), + ("bert", "TFBertForMultipleChoice"), ("camembert", "TFCamembertForMultipleChoice"), + ("convbert", "TFConvBertForMultipleChoice"), + ("distilbert", "TFDistilBertForMultipleChoice"), + ("electra", "TFElectraForMultipleChoice"), + ("flaubert", "TFFlaubertForMultipleChoice"), + ("funnel", "TFFunnelForMultipleChoice"), + ("longformer", "TFLongformerForMultipleChoice"), + ("mobilebert", "TFMobileBertForMultipleChoice"), + ("mpnet", "TFMPNetForMultipleChoice"), + ("rembert", "TFRemBertForMultipleChoice"), + ("roberta", "TFRobertaForMultipleChoice"), + ("roformer", "TFRoFormerForMultipleChoice"), ("xlm", "TFXLMForMultipleChoice"), ("xlm-roberta", "TFXLMRobertaForMultipleChoice"), - ("longformer", "TFLongformerForMultipleChoice"), - ("roberta", "TFRobertaForMultipleChoice"), 
- ("bert", "TFBertForMultipleChoice"), - ("distilbert", "TFDistilBertForMultipleChoice"), - ("mobilebert", "TFMobileBertForMultipleChoice"), ("xlnet", "TFXLNetForMultipleChoice"), - ("flaubert", "TFFlaubertForMultipleChoice"), - ("albert", "TFAlbertForMultipleChoice"), - ("electra", "TFElectraForMultipleChoice"), - ("funnel", "TFFunnelForMultipleChoice"), - ("mpnet", "TFMPNetForMultipleChoice"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 32b8cce1fe9..5f7896b1289 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -41,17 +41,17 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("flava", "FLAVAProcessor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutxlm", "LayoutXLMProcessor"), + ("sew", "Wav2Vec2Processor"), + ("sew-d", "Wav2Vec2Processor"), ("speech_to_text", "Speech2TextProcessor"), ("speech_to_text_2", "Speech2Text2Processor"), ("trocr", "TrOCRProcessor"), - ("wav2vec2", "Wav2Vec2Processor"), - ("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"), - ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), ("unispeech", "Wav2Vec2Processor"), ("unispeech-sat", "Wav2Vec2Processor"), - ("sew", "Wav2Vec2Processor"), - ("sew-d", "Wav2Vec2Processor"), ("vilt", "ViltProcessor"), + ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), + ("wav2vec2", "Wav2Vec2Processor"), + ("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"), ("wavlm", "Wav2Vec2Processor"), ] ) @@ -65,7 +65,10 @@ def processor_class_from_name(class_name: str): module_name = model_type_to_module_name(module_name) module = importlib.import_module(f".{module_name}", "transformers.models") - return getattr(module, class_name) + try: + return getattr(module, class_name) + except AttributeError: + continue for processor in PROCESSOR_MAPPING._extra_content.values(): if getattr(processor, "__name__", None) == class_name: diff --git 
a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 4ec7d96ebca..1e7033e1954 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -46,27 +46,6 @@ if TYPE_CHECKING: else: TOKENIZER_MAPPING_NAMES = OrderedDict( [ - ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), - ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), - ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), - ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), - ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), - ( - "t5", - ( - "T5Tokenizer" if is_sentencepiece_available() else None, - "T5TokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "mt5", - ( - "MT5Tokenizer" if is_sentencepiece_available() else None, - "MT5TokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), - ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)), ( "albert", ( @@ -74,6 +53,30 @@ else: "AlbertTokenizerFast" if is_tokenizers_available() else None, ), ), + ("bart", ("BartTokenizer", "BartTokenizerFast")), + ( + "barthez", + ( + "BarthezTokenizer" if is_sentencepiece_available() else None, + "BarthezTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("bartpho", ("BartphoTokenizer", None)), + ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)), + ("bert-japanese", ("BertJapaneseTokenizer", None)), + ("bertweet", ("BertweetTokenizer", None)), + ( + "big_bird", + ( 
+ "BigBirdTokenizer" if is_sentencepiece_available() else None, + "BigBirdTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), + ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")), + ("blenderbot-small", ("BlenderbotSmallTokenizer", None)), + ("byt5", ("ByT5Tokenizer", None)), ( "camembert", ( @@ -81,76 +84,24 @@ else: "CamembertTokenizerFast" if is_tokenizers_available() else None, ), ), + ("canine", ("CanineTokenizer", None)), ( - "pegasus", + "clip", ( - "PegasusTokenizer" if is_sentencepiece_available() else None, - "PegasusTokenizerFast" if is_tokenizers_available() else None, + "CLIPTokenizer", + "CLIPTokenizerFast" if is_tokenizers_available() else None, ), ), + ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), ( - "mbart", + "cpm", ( - "MBartTokenizer" if is_sentencepiece_available() else None, - "MBartTokenizerFast" if is_tokenizers_available() else None, + "CpmTokenizer" if is_sentencepiece_available() else None, + "CpmTokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "xlm-roberta", - ( - "XLMRobertaTokenizer" if is_sentencepiece_available() else None, - "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), - ("blenderbot-small", ("BlenderbotSmallTokenizer", None)), - ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")), - ("tapex", ("TapexTokenizer", None)), - ("bart", ("BartTokenizer", "BartTokenizerFast")), - ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), - ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), - ( - "reformer", - ( - "ReformerTokenizer" if is_sentencepiece_available() else None, - "ReformerTokenizerFast" if 
is_tokenizers_available() else None, - ), - ), - ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), - ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), - ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), - ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), - ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), - ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)), - ( - "dpr", - ( - "DPRQuestionEncoderTokenizer", - "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "squeezebert", - ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), - ), - ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), - ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), - ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), - ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), - ("opt", ("GPT2Tokenizer", None)), - ("transfo-xl", ("TransfoXLTokenizer", None)), - ( - "xlnet", - ( - "XLNetTokenizer" if is_sentencepiece_available() else None, - "XLNetTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("flaubert", ("FlaubertTokenizer", None)), - ("xlm", ("XLMTokenizer", None)), ("ctrl", ("CTRLTokenizer", None)), - ("fsmt", ("FSMTTokenizer", None)), - ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)), + ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)), ( 
"deberta-v2", @@ -159,51 +110,39 @@ else: "DebertaV2TokenizerFast" if is_tokenizers_available() else None, ), ), - ("rag", ("RagTokenizer", None)), - ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)), - ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), - ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), - ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), - ("prophetnet", ("ProphetNetTokenizer", None)), - ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), - ("tapas", ("TapasTokenizer", None)), - ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), - ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)), + ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)), ( - "big_bird", + "dpr", ( - "BigBirdTokenizer" if is_sentencepiece_available() else None, - "BigBirdTokenizerFast" if is_tokenizers_available() else None, + "DPRQuestionEncoderTokenizer", + "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None, ), ), - ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), - ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), - ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), - ("hubert", ("Wav2Vec2CTCTokenizer", None)), + ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), + ("flaubert", ("FlaubertTokenizer", None)), + ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)), + ("fsmt", ("FSMTTokenizer", None)), + ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), + ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt_neo", 
("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), - ("luke", ("LukeTokenizer", None)), - ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), - ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)), - ("canine", ("CanineTokenizer", None)), - ("bertweet", ("BertweetTokenizer", None)), - ("bert-japanese", ("BertJapaneseTokenizer", None)), - ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")), - ("byt5", ("ByT5Tokenizer", None)), - ( - "cpm", - ( - "CpmTokenizer" if is_sentencepiece_available() else None, - "CpmTokenizerFast" if is_tokenizers_available() else None, - ), - ), + ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), - ("phobert", ("PhobertTokenizer", None)), - ("bartpho", ("BartphoTokenizer", None)), + ("hubert", ("Wav2Vec2CTCTokenizer", None)), + ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), + ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), + ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)), + ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), + ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), + ("luke", ("LukeTokenizer", None)), + ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), + ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), + ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), ( - "barthez", + "mbart", ( - "BarthezTokenizer" if 
is_sentencepiece_available() else None, - "BarthezTokenizerFast" if is_tokenizers_available() else None, + "MBartTokenizer" if is_sentencepiece_available() else None, + "MBartTokenizerFast" if is_tokenizers_available() else None, ), ), ( @@ -213,37 +152,17 @@ else: "MBart50TokenizerFast" if is_tokenizers_available() else None, ), ), - ( - "rembert", - ( - "RemBertTokenizer" if is_sentencepiece_available() else None, - "RemBertTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ( - "clip", - ( - "CLIPTokenizer", - "CLIPTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)), - ( - "perceiver", - ( - "PerceiverTokenizer", - None, - ), - ), - ( - "xglm", - ( - "XGLMTokenizer" if is_sentencepiece_available() else None, - "XGLMTokenizerFast" if is_tokenizers_available() else None, - ), - ), - ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)), + ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), + ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)), + ( + "mt5", + ( + "MT5Tokenizer" if is_sentencepiece_available() else None, + "MT5TokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "nystromformer", ( @@ -251,7 +170,89 @@ else: "AlbertTokenizerFast" if is_tokenizers_available() else None, ), ), + ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), + ("opt", ("GPT2Tokenizer", None)), + ( + "pegasus", + ( + "PegasusTokenizer" if is_sentencepiece_available() else None, + "PegasusTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "perceiver", + ( + "PerceiverTokenizer", + None, + ), + ), + 
("phobert", ("PhobertTokenizer", None)), + ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)), + ("prophetnet", ("ProphetNetTokenizer", None)), + ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("rag", ("RagTokenizer", None)), + ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)), + ( + "reformer", + ( + "ReformerTokenizer" if is_sentencepiece_available() else None, + "ReformerTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "rembert", + ( + "RemBertTokenizer" if is_sentencepiece_available() else None, + "RemBertTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)), + ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), + ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)), + ("speech_to_text_2", ("Speech2Text2Tokenizer", None)), + ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")), + ( + "squeezebert", + ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None), + ), + ( + "t5", + ( + "T5Tokenizer" if is_sentencepiece_available() else None, + "T5TokenizerFast" if is_tokenizers_available() else None, + ), + ), + ("tapas", ("TapasTokenizer", None)), + ("tapex", ("TapexTokenizer", None)), + ("transfo-xl", ("TransfoXLTokenizer", None)), + ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)), + ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)), + ( + "xglm", + ( + "XGLMTokenizer" if is_sentencepiece_available() else None, + "XGLMTokenizerFast" if is_tokenizers_available() else None, + ), + ), + 
("xlm", ("XLMTokenizer", None)), + ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)), + ( + "xlm-roberta", + ( + "XLMRobertaTokenizer" if is_sentencepiece_available() else None, + "XLMRobertaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), + ( + "xlnet", + ( + "XLNetTokenizer" if is_sentencepiece_available() else None, + "XLNetTokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "yoso", ( @@ -259,7 +260,6 @@ else: "AlbertTokenizerFast" if is_tokenizers_available() else None, ), ), - ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ] ) @@ -277,7 +277,10 @@ def tokenizer_class_from_name(class_name: str): module_name = model_type_to_module_name(module_name) module = importlib.import_module(f".{module_name}", "transformers.models") - return getattr(module, class_name) + try: + return getattr(module, class_name) + except AttributeError: + continue for config, tokenizers in TOKENIZER_MAPPING._extra_content.items(): for tokenizer in tokenizers: diff --git a/tests/models/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py index eeb10ad2d31..2695082c412 100644 --- a/tests/models/auto/test_configuration_auto.py +++ b/tests/models/auto/test_configuration_auto.py @@ -14,6 +14,8 @@ # limitations under the License. import importlib +import json +import os import sys import tempfile import unittest @@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase): self.assertIsInstance(config, RobertaConfig) def test_pattern_matching_fallback(self): - """ - In cases where config.json doesn't include a model_type, - perform a few safety checks on the config mapping's order. 
- """ - # no key string should be included in a later key string (typical failure case) - keys = list(CONFIG_MAPPING.keys()) - for i, key in enumerate(keys): - self.assertFalse(any(key in later_key for later_key in keys[i + 1 :])) + with tempfile.TemporaryDirectory() as tmp_dir: + # This model name contains bert and roberta, but roberta ends up being picked. + folder = os.path.join(tmp_dir, "fake-roberta") + os.makedirs(folder, exist_ok=True) + with open(os.path.join(folder, "config.json"), "w") as f: + f.write(json.dumps({})) + config = AutoConfig.from_pretrained(folder) + self.assertEqual(type(config), RobertaConfig) def test_new_config_registration(self): try: diff --git a/utils/sort_auto_mappings.py b/utils/sort_auto_mappings.py new file mode 100644 index 00000000000..ef985dc43cd --- /dev/null +++ b/utils/sort_auto_mappings.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import re + + +PATH_TO_AUTO_MODULE = "src/transformers/models/auto" + + +# re pattern that matches mapping introductions: +# SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict +_re_intro_mapping = re.compile("[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict") +# re pattern that matches identifiers in mappings +_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"') + + +def sort_auto_mapping(fname, overwrite: bool = False): + with open(fname, "r", encoding="utf-8") as f: + content = f.read() + + lines = content.split("\n") + new_lines = [] + line_idx = 0 + while line_idx < len(lines): + if _re_intro_mapping.search(lines[line_idx]) is not None: + indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8 + # Start of a new mapping! + while not lines[line_idx].startswith(" " * indent + "("): + new_lines.append(lines[line_idx]) + line_idx += 1 + + blocks = [] + while lines[line_idx].strip() != "]": + # Blocks either fit in one line or not + if lines[line_idx].strip() == "(": + start_idx = line_idx + while not lines[line_idx].startswith(" " * indent + ")"): + line_idx += 1 + blocks.append("\n".join(lines[start_idx : line_idx + 1])) + else: + blocks.append(lines[line_idx]) + line_idx += 1 + + # Sort blocks by their identifiers + blocks = sorted(blocks, key=lambda x: _re_identifier.search(x).groups()[0]) + new_lines += blocks + else: + new_lines.append(lines[line_idx]) + line_idx += 1 + + if overwrite: + with open(fname, "w", encoding="utf-8") as f: + f.write("\n".join(new_lines)) + elif "\n".join(new_lines) != content: + return True + + +def sort_all_auto_mappings(overwrite: bool = False): + fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")] + diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames] + + if not overwrite and any(diffs): + failures = [f for f, d in zip(fnames, diffs) if d] + raise ValueError( + f"The following files 
have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix" + " this." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.") + args = parser.parse_args() + + sort_all_auto_mappings(not args.check_only)