Automatically sort auto mappings (#17250)

* Automatically sort auto mappings

* Better class extraction

* Some auto class magic

* Adapt test and underlying behavior

* Remove re-used config

* Quality
This commit is contained in:
Sylvain Gugger 2022-05-16 13:24:20 -04:00 committed by GitHub
parent 2f611f85e2
commit ddb1a47ec8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1199 additions and 1094 deletions

View File

@ -857,6 +857,7 @@ jobs:
- run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: flake8 examples tests src utils
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source

View File

@ -48,6 +48,7 @@ quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
@ -55,6 +56,7 @@ quality:
extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
# this target runs checks on all files and potentially modifies some of them

View File

@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow.
| Swin | ❌ | ❌ | ✅ | ❌ | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| TAPEX | ✅ | ✅ | ✅ | ✅ | ✅ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |

View File

@ -74,7 +74,6 @@ Ready-made configurations include the following architectures:
- RoBERTa
- RoFormer
- T5
- TAPEX
- ViT
- XLM-RoBERTa
- XLM-RoBERTa-XL

View File

@ -560,10 +560,17 @@ class _LazyAutoMapping(OrderedDict):
if key in self._extra_content:
return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
if model_type in self._model_mapping:
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
# Maybe there were several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping:
model_name = self._model_mapping[mtype]
return self._load_attr_from_module(mtype, model_name)
raise KeyError(key)
def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)

View File

@ -29,339 +29,338 @@ logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("yolos", "YolosConfig"),
("tapex", "BartConfig"),
("dpt", "DPTConfig"),
("decision_transformer", "DecisionTransformerConfig"),
("glpn", "GLPNConfig"),
("maskformer", "MaskFormerConfig"),
("decision_transformer", "DecisionTransformerConfig"),
("poolformer", "PoolFormerConfig"),
("convnext", "ConvNextConfig"),
("van", "VanConfig"),
("resnet", "ResNetConfig"),
("regnet", "RegNetConfig"),
("yoso", "YosoConfig"),
("swin", "SwinConfig"),
("vilt", "ViltConfig"),
("vit_mae", "ViTMAEConfig"),
("realm", "RealmConfig"),
("nystromformer", "NystromformerConfig"),
("xglm", "XGLMConfig"),
("imagegpt", "ImageGPTConfig"),
("qdqbert", "QDQBertConfig"),
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
("trocr", "TrOCRConfig"),
("fnet", "FNetConfig"),
("segformer", "SegformerConfig"),
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("perceiver", "PerceiverConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("plbart", "PLBartConfig"),
("beit", "BeitConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
("flava", "FlavaConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
("detr", "DetrConfig"),
("gpt_neo", "GPTNeoConfig"),
("big_bird", "BigBirdConfig"),
("speech_to_text_2", "Speech2Text2Config"),
("speech_to_text", "Speech2TextConfig"),
("vit", "ViTConfig"),
("wav2vec2", "Wav2Vec2Config"),
("m2m_100", "M2M100Config"),
("convbert", "ConvBertConfig"),
("led", "LEDConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("retribert", "RetriBertConfig"),
("ibert", "IBertConfig"),
("mt5", "MT5Config"),
("t5", "T5Config"),
("mobilebert", "MobileBertConfig"),
("distilbert", "DistilBertConfig"),
("albert", "AlbertConfig"),
("bert-generation", "BertGenerationConfig"),
("camembert", "CamembertConfig"),
("xlm-roberta-xl", "XLMRobertaXLConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("pegasus", "PegasusConfig"),
("marian", "MarianConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mpnet", "MPNetConfig"),
("bart", "BartConfig"),
("opt", "OPTConfig"),
("blenderbot", "BlenderbotConfig"),
("reformer", "ReformerConfig"),
("longformer", "LongformerConfig"),
("roberta", "RobertaConfig"),
("deberta-v2", "DebertaV2Config"),
("deberta", "DebertaConfig"),
("flaubert", "FlaubertConfig"),
("fsmt", "FSMTConfig"),
("squeezebert", "SqueezeBertConfig"),
("hubert", "HubertConfig"),
("beit", "BeitConfig"),
("bert", "BertConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("gpt2", "GPT2Config"),
("transfo-xl", "TransfoXLConfig"),
("xlnet", "XLNetConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("prophetnet", "ProphetNetConfig"),
("xlm", "XLMConfig"),
("bert-generation", "BertGenerationConfig"),
("big_bird", "BigBirdConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("blenderbot", "BlenderbotConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("camembert", "CamembertConfig"),
("canine", "CanineConfig"),
("clip", "CLIPConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
("ctrl", "CTRLConfig"),
("electra", "ElectraConfig"),
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("funnel", "FunnelConfig"),
("lxmert", "LxmertConfig"),
("dpr", "DPRConfig"),
("layoutlm", "LayoutLMConfig"),
("rag", "RagConfig"),
("tapas", "TapasConfig"),
("splinter", "SplinterConfig"),
("sew-d", "SEWDConfig"),
("sew", "SEWConfig"),
("unispeech-sat", "UniSpeechSatConfig"),
("unispeech", "UniSpeechConfig"),
("wavlm", "WavLMConfig"),
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
("deberta", "DebertaConfig"),
("deberta-v2", "DebertaV2Config"),
("decision_transformer", "DecisionTransformerConfig"),
("decision_transformer", "DecisionTransformerConfig"),
("deit", "DeiTConfig"),
("detr", "DetrConfig"),
("distilbert", "DistilBertConfig"),
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
("electra", "ElectraConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("flaubert", "FlaubertConfig"),
("flava", "FlavaConfig"),
("fnet", "FNetConfig"),
("fsmt", "FSMTConfig"),
("funnel", "FunnelConfig"),
("glpn", "GLPNConfig"),
("gpt2", "GPT2Config"),
("gpt_neo", "GPTNeoConfig"),
("gptj", "GPTJConfig"),
("hubert", "HubertConfig"),
("ibert", "IBertConfig"),
("imagegpt", "ImageGPTConfig"),
("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("led", "LEDConfig"),
("longformer", "LongformerConfig"),
("luke", "LukeConfig"),
("lxmert", "LxmertConfig"),
("m2m_100", "M2M100Config"),
("marian", "MarianConfig"),
("maskformer", "MaskFormerConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mobilebert", "MobileBertConfig"),
("mpnet", "MPNetConfig"),
("mt5", "MT5Config"),
("nystromformer", "NystromformerConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("opt", "OPTConfig"),
("pegasus", "PegasusConfig"),
("perceiver", "PerceiverConfig"),
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("prophetnet", "ProphetNetConfig"),
("qdqbert", "QDQBertConfig"),
("rag", "RagConfig"),
("realm", "RealmConfig"),
("reformer", "ReformerConfig"),
("regnet", "RegNetConfig"),
("rembert", "RemBertConfig"),
("resnet", "ResNetConfig"),
("retribert", "RetriBertConfig"),
("roberta", "RobertaConfig"),
("roformer", "RoFormerConfig"),
("segformer", "SegformerConfig"),
("sew", "SEWConfig"),
("sew-d", "SEWDConfig"),
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
("speech_to_text", "Speech2TextConfig"),
("speech_to_text_2", "Speech2Text2Config"),
("splinter", "SplinterConfig"),
("squeezebert", "SqueezeBertConfig"),
("swin", "SwinConfig"),
("t5", "T5Config"),
("tapas", "TapasConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
("unispeech", "UniSpeechConfig"),
("unispeech-sat", "UniSpeechSatConfig"),
("van", "VanConfig"),
("vilt", "ViltConfig"),
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("visual_bert", "VisualBertConfig"),
("vit", "ViTConfig"),
("vit_mae", "ViTMAEConfig"),
("wav2vec2", "Wav2Vec2Config"),
("wavlm", "WavLMConfig"),
("xglm", "XGLMConfig"),
("xlm", "XLMConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("xlm-roberta-xl", "XLMRobertaXLConfig"),
("xlnet", "XLNetConfig"),
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
]
)
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
]
)
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("yolos", "YOLOS"),
("tapex", "TAPEX"),
("dpt", "DPT"),
("decision_transformer", "Decision Transformer"),
("glpn", "GLPN"),
("maskformer", "MaskFormer"),
("poolformer", "PoolFormer"),
("convnext", "ConvNext"),
("van", "VAN"),
("resnet", "ResNet"),
("regnet", "RegNet"),
("yoso", "YOSO"),
("swin", "Swin"),
("vilt", "ViLT"),
("vit_mae", "ViTMAE"),
("realm", "Realm"),
("nystromformer", "Nystromformer"),
("xglm", "XGLM"),
("imagegpt", "ImageGPT"),
("qdqbert", "QDQBert"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("trocr", "TrOCR"),
("fnet", "FNet"),
("segformer", "SegFormer"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("perceiver", "Perceiver"),
("gptj", "GPT-J"),
("beit", "BEiT"),
("plbart", "PLBart"),
("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"),
("visual_bert", "VisualBert"),
("canine", "Canine"),
("roformer", "RoFormer"),
("clip", "CLIP"),
("flava", "Flava"),
("bigbird_pegasus", "BigBirdPegasus"),
("deit", "DeiT"),
("luke", "LUKE"),
("detr", "DETR"),
("gpt_neo", "GPT Neo"),
("big_bird", "BigBird"),
("speech_to_text_2", "Speech2Text2"),
("speech_to_text", "Speech2Text"),
("vit", "ViT"),
("wav2vec2", "Wav2Vec2"),
("m2m_100", "M2M100"),
("convbert", "ConvBERT"),
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),
("retribert", "RetriBERT"),
("ibert", "I-BERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),
("distilbert", "DistilBERT"),
("albert", "ALBERT"),
("bert-generation", "Bert Generation"),
("camembert", "CamemBERT"),
("xlm-roberta", "XLM-RoBERTa"),
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
("pegasus", "Pegasus"),
("blenderbot", "Blenderbot"),
("marian", "Marian"),
("mbart", "mBART"),
("megatron-bert", "MegatronBert"),
("bart", "BART"),
("opt", "OPT"),
("reformer", "Reformer"),
("longformer", "Longformer"),
("roberta", "RoBERTa"),
("flaubert", "FlauBERT"),
("fsmt", "FairSeq Machine-Translation"),
("squeezebert", "SqueezeBERT"),
("bert", "BERT"),
("openai-gpt", "OpenAI GPT"),
("gpt2", "OpenAI GPT-2"),
("transfo-xl", "Transformer-XL"),
("xlnet", "XLNet"),
("xlm", "XLM"),
("ctrl", "CTRL"),
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("deberta-v2", "DeBERTa-v2"),
("deberta", "DeBERTa"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
("xlm-prophetnet", "XLMProphetNet"),
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
("mpnet", "MPNet"),
("tapas", "TAPAS"),
("hubert", "Hubert"),
("barthez", "BARThez"),
("phobert", "PhoBERT"),
("bartpho", "BARTpho"),
("cpm", "CPM"),
("bertweet", "Bertweet"),
("beit", "BEiT"),
("bert", "BERT"),
("bert-generation", "Bert Generation"),
("bert-japanese", "BertJapanese"),
("byt5", "ByT5"),
("mbart50", "mBART-50"),
("splinter", "Splinter"),
("sew-d", "SEW-D"),
("sew", "SEW"),
("unispeech-sat", "UniSpeechSat"),
("unispeech", "UniSpeech"),
("wavlm", "WavLM"),
("bertweet", "Bertweet"),
("big_bird", "BigBird"),
("bigbird_pegasus", "BigBirdPegasus"),
("blenderbot", "Blenderbot"),
("blenderbot-small", "BlenderbotSmall"),
("bort", "BORT"),
("dialogpt", "DialoGPT"),
("xls_r", "XLS-R"),
("t5v1.1", "T5v1.1"),
("herbert", "HerBERT"),
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
("megatron_gpt2", "MegatronGPT2"),
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
("mluke", "mLUKE"),
("layoutxlm", "LayoutXLM"),
("byt5", "ByT5"),
("camembert", "CamemBERT"),
("canine", "Canine"),
("clip", "CLIP"),
("convbert", "ConvBERT"),
("convnext", "ConvNext"),
("cpm", "CPM"),
("ctrl", "CTRL"),
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
("data2vec-vision", "Data2VecVision"),
("deberta", "DeBERTa"),
("deberta-v2", "DeBERTa-v2"),
("decision_transformer", "Decision Transformer"),
("deit", "DeiT"),
("detr", "DETR"),
("dialogpt", "DialoGPT"),
("distilbert", "DistilBERT"),
("dit", "DiT"),
("dpr", "DPR"),
("dpt", "DPT"),
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("flaubert", "FlauBERT"),
("flava", "Flava"),
("fnet", "FNet"),
("fsmt", "FairSeq Machine-Translation"),
("funnel", "Funnel Transformer"),
("glpn", "GLPN"),
("gpt2", "OpenAI GPT-2"),
("gpt_neo", "GPT Neo"),
("gptj", "GPT-J"),
("herbert", "HerBERT"),
("hubert", "Hubert"),
("ibert", "I-BERT"),
("imagegpt", "ImageGPT"),
("layoutlm", "LayoutLM"),
("layoutlmv2", "LayoutLMv2"),
("layoutxlm", "LayoutXLM"),
("led", "LED"),
("longformer", "Longformer"),
("luke", "LUKE"),
("lxmert", "LXMERT"),
("m2m_100", "M2M100"),
("marian", "Marian"),
("maskformer", "MaskFormer"),
("mbart", "mBART"),
("mbart50", "mBART-50"),
("megatron-bert", "MegatronBert"),
("megatron_gpt2", "MegatronGPT2"),
("mluke", "mLUKE"),
("mobilebert", "MobileBERT"),
("mpnet", "MPNet"),
("mt5", "mT5"),
("nystromformer", "Nystromformer"),
("openai-gpt", "OpenAI GPT"),
("opt", "OPT"),
("pegasus", "Pegasus"),
("perceiver", "Perceiver"),
("phobert", "PhoBERT"),
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("prophetnet", "ProphetNet"),
("qdqbert", "QDQBert"),
("rag", "RAG"),
("realm", "Realm"),
("reformer", "Reformer"),
("regnet", "RegNet"),
("rembert", "RemBERT"),
("resnet", "ResNet"),
("retribert", "RetriBERT"),
("roberta", "RoBERTa"),
("roformer", "RoFormer"),
("segformer", "SegFormer"),
("sew", "SEW"),
("sew-d", "SEW-D"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("speech_to_text", "Speech2Text"),
("speech_to_text_2", "Speech2Text2"),
("splinter", "Splinter"),
("squeezebert", "SqueezeBERT"),
("swin", "Swin"),
("t5", "T5"),
("t5v1.1", "T5v1.1"),
("tapas", "TAPAS"),
("tapex", "TAPEX"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),
("unispeech", "UniSpeech"),
("unispeech-sat", "UniSpeechSat"),
("van", "VAN"),
("vilt", "ViLT"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("visual_bert", "VisualBert"),
("vit", "ViT"),
("vit_mae", "ViTMAE"),
("wav2vec2", "Wav2Vec2"),
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
("wavlm", "WavLM"),
("xglm", "XGLM"),
("xlm", "XLM"),
("xlm-prophetnet", "XLMProphetNet"),
("xlm-roberta", "XLM-RoBERTa"),
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
("xlnet", "XLNet"),
("xls_r", "XLS-R"),
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
("yolos", "YOLOS"),
("yoso", "YOSO"),
]
)
@ -703,9 +702,10 @@ class AutoConfig:
return config_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, config_class in CONFIG_MAPPING.items():
# We go from longer names to shorter names to catch roberta before bert (for instance)
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
if pattern in str(pretrained_model_name_or_path):
return config_class.from_dict(config_dict, **kwargs)
return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)
raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. "

View File

@ -38,30 +38,30 @@ logger = logging.get_logger(__name__)
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("dpt", "DPTFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("yolos", "YolosFeatureExtractor"),
]
)
@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
break
try:
return getattr(module, class_name)
except AttributeError:
continue
for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name:

View File

@ -28,257 +28,257 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("yolos", "YolosModel"),
("dpt", "DPTModel"),
("decision_transformer", "DecisionTransformerModel"),
("glpn", "GLPNModel"),
("maskformer", "MaskFormerModel"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
("poolformer", "PoolFormerModel"),
("convnext", "ConvNextModel"),
("van", "VanModel"),
("resnet", "ResNetModel"),
("regnet", "RegNetModel"),
("yoso", "YosoModel"),
("swin", "SwinModel"),
("vilt", "ViltModel"),
("vit_mae", "ViTMAEModel"),
("nystromformer", "NystromformerModel"),
("xglm", "XGLMModel"),
("imagegpt", "ImageGPTModel"),
("qdqbert", "QDQBertModel"),
("fnet", "FNetModel"),
("segformer", "SegformerModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("perceiver", "PerceiverModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("plbart", "PLBartModel"),
("beit", "BeitModel"),
("data2vec-vision", "Data2VecVisionModel"),
("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"),
("canine", "CanineModel"),
("roformer", "RoFormerModel"),
("clip", "CLIPModel"),
("flava", "FlavaModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("deit", "DeiTModel"),
("luke", "LukeModel"),
("detr", "DetrModel"),
("gpt_neo", "GPTNeoModel"),
("big_bird", "BigBirdModel"),
("speech_to_text", "Speech2TextModel"),
("vit", "ViTModel"),
("wav2vec2", "Wav2Vec2Model"),
("unispeech-sat", "UniSpeechSatModel"),
("wavlm", "WavLMModel"),
("unispeech", "UniSpeechModel"),
("hubert", "HubertModel"),
("m2m_100", "M2M100Model"),
("convbert", "ConvBertModel"),
("led", "LEDModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("retribert", "RetriBertModel"),
("mt5", "MT5Model"),
("t5", "T5Model"),
("pegasus", "PegasusModel"),
("marian", "MarianModel"),
("mbart", "MBartModel"),
("blenderbot", "BlenderbotModel"),
("distilbert", "DistilBertModel"),
("albert", "AlbertModel"),
("camembert", "CamembertModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
("xlm-roberta", "XLMRobertaModel"),
("bart", "BartModel"),
("opt", "OPTModel"),
("longformer", "LongformerModel"),
("roberta", "RobertaModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-audio", "Data2VecAudioModel"),
("layoutlm", "LayoutLMModel"),
("squeezebert", "SqueezeBertModel"),
("beit", "BeitModel"),
("bert", "BertModel"),
("openai-gpt", "OpenAIGPTModel"),
("gpt2", "GPT2Model"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("transfo-xl", "TransfoXLModel"),
("xlnet", "XLNetModel"),
("flaubert", "FlaubertModel"),
("fsmt", "FSMTModel"),
("xlm", "XLMModel"),
("ctrl", "CTRLModel"),
("electra", "ElectraModel"),
("reformer", "ReformerModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("lxmert", "LxmertModel"),
("bert-generation", "BertGenerationEncoder"),
("big_bird", "BigBirdModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("blenderbot", "BlenderbotModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("camembert", "CamembertModel"),
("canine", "CanineModel"),
("clip", "CLIPModel"),
("convbert", "ConvBertModel"),
("convnext", "ConvNextModel"),
("ctrl", "CTRLModel"),
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-vision", "Data2VecVisionModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
("deit", "DeiTModel"),
("detr", "DetrModel"),
("distilbert", "DistilBertModel"),
("dpr", "DPRQuestionEncoder"),
("xlm-prophetnet", "XLMProphetNetModel"),
("prophetnet", "ProphetNetModel"),
("mpnet", "MPNetModel"),
("tapas", "TapasModel"),
("dpt", "DPTModel"),
("electra", "ElectraModel"),
("flaubert", "FlaubertModel"),
("flava", "FlavaModel"),
("fnet", "FNetModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("glpn", "GLPNModel"),
("gpt2", "GPT2Model"),
("gpt_neo", "GPTNeoModel"),
("gptj", "GPTJModel"),
("hubert", "HubertModel"),
("ibert", "IBertModel"),
("splinter", "SplinterModel"),
("imagegpt", "ImageGPTModel"),
("layoutlm", "LayoutLMModel"),
("layoutlmv2", "LayoutLMv2Model"),
("led", "LEDModel"),
("longformer", "LongformerModel"),
("luke", "LukeModel"),
("lxmert", "LxmertModel"),
("m2m_100", "M2M100Model"),
("marian", "MarianModel"),
("maskformer", "MaskFormerModel"),
("mbart", "MBartModel"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("mpnet", "MPNetModel"),
("mt5", "MT5Model"),
("nystromformer", "NystromformerModel"),
("openai-gpt", "OpenAIGPTModel"),
("opt", "OPTModel"),
("pegasus", "PegasusModel"),
("perceiver", "PerceiverModel"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
("qdqbert", "QDQBertModel"),
("reformer", "ReformerModel"),
("regnet", "RegNetModel"),
("rembert", "RemBertModel"),
("resnet", "ResNetModel"),
("retribert", "RetriBertModel"),
("roberta", "RobertaModel"),
("roformer", "RoFormerModel"),
("segformer", "SegformerModel"),
("sew", "SEWModel"),
("sew-d", "SEWDModel"),
("speech_to_text", "Speech2TextModel"),
("splinter", "SplinterModel"),
("squeezebert", "SqueezeBertModel"),
("swin", "SwinModel"),
("t5", "T5Model"),
("tapas", "TapasModel"),
("transfo-xl", "TransfoXLModel"),
("unispeech", "UniSpeechModel"),
("unispeech-sat", "UniSpeechSatModel"),
("van", "VanModel"),
("vilt", "ViltModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("visual_bert", "VisualBertModel"),
("vit", "ViTModel"),
("vit_mae", "ViTMAEModel"),
("wav2vec2", "Wav2Vec2Model"),
("wavlm", "WavLMModel"),
("xglm", "XGLMModel"),
("xlm", "XLMModel"),
("xlm-prophetnet", "XLMProphetNetModel"),
("xlm-roberta", "XLMRobertaModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
("xlnet", "XLNetModel"),
("yolos", "YolosModel"),
("yoso", "YosoModel"),
]
)
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("flava", "FlavaForPreTraining"),
("vit_mae", "ViTMAEForPreTraining"),
("fnet", "FNetForPreTraining"),
("visual_bert", "VisualBertForPreTraining"),
("layoutlm", "LayoutLMForMaskedLM"),
("retribert", "RetriBertModel"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForPreTraining"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("camembert", "CamembertForMaskedLM"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForPreTraining"),
("lxmert", "LxmertForPreTraining"),
("funnel", "FunnelForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
("unispeech-sat", "UniSpeechSatForPreTraining"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForPreTraining"),
("flaubert", "FlaubertWithLMHeadModel"),
("flava", "FlavaForPreTraining"),
("fnet", "FNetForPreTraining"),
("fsmt", "FSMTForConditionalGeneration"),
("funnel", "FunnelForPreTraining"),
("gpt2", "GPT2LMHeadModel"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("retribert", "RetriBertModel"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("tapas", "TapasForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("unispeech", "UniSpeechForPreTraining"),
("unispeech-sat", "UniSpeechSatForPreTraining"),
("visual_bert", "VisualBertForPreTraining"),
("vit_mae", "ViTMAEForPreTraining"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlnet", "XLNetLMHeadModel"),
]
)
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("yoso", "YosoForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("plbart", "PLBartForConditionalGeneration"),
("qdqbert", "QDQBertForMaskedLM"),
("fnet", "FNetForMaskedLM"),
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("convbert", "ConvBertForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("marian", "MarianMTModel"),
("fsmt", "FSMTForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("big_bird", "BigBirdForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("reformer", "ReformerModelWithLMHead"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("fnet", "FNetForMaskedLM"),
("fsmt", "FSMTForConditionalGeneration"),
("funnel", "FunnelForMaskedLM"),
("gpt2", "GPT2LMHeadModel"),
("gpt_neo", "GPTNeoForCausalLM"),
("gptj", "GPTJForCausalLM"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("plbart", "PLBartForConditionalGeneration"),
("qdqbert", "QDQBertForMaskedLM"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("squeezebert", "SqueezeBertForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("tapas", "TapasForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlnet", "XLNetLMHeadModel"),
("yoso", "YosoForMaskedLM"),
]
)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("xglm", "XGLMForCausalLM"),
("plbart", "PLBartForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForCausalLM"),
("camembert", "CamembertForCausalLM"),
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("roberta", "RobertaForCausalLM"),
("bert", "BertLMHeadModel"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("bert-generation", "BertGenerationDecoder"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("bart", "BartForCausalLM"),
("opt", "OPTForCausalLM"),
("mbart", "MBartForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("marian", "MarianForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
("big_bird", "BigBirdForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
("camembert", "CamembertForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("electra", "ElectraForCausalLM"),
("gpt2", "GPT2LMHeadModel"),
("gpt_neo", "GPTNeoForCausalLM"),
("gptj", "GPTJForCausalLM"),
("marian", "MarianForCausalLM"),
("mbart", "MBartForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("opt", "OPTForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("plbart", "PLBartForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForCausalLM"),
("roberta", "RobertaForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
("xglm", "XGLMForCausalLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
("xlnet", "XLNetLMHeadModel"),
]
)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("vit", "ViTForMaskedImageModeling"),
("deit", "DeiTForMaskedImageModeling"),
("swin", "SwinForMaskedImageModeling"),
("vit", "ViTForMaskedImageModeling"),
]
)
@ -293,11 +293,10 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
("vit", "ViTForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("beit", "BeitForImageClassification"),
("convnext", "ConvNextForImageClassification"),
("data2vec-vision", "Data2VecVisionForImageClassification"),
("segformer", "SegformerForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("imagegpt", "ImageGPTForImageClassification"),
(
"perceiver",
@ -307,12 +306,13 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
"PerceiverForImageClassificationConvProcessing",
),
),
("swin", "SwinForImageClassification"),
("convnext", "ConvNextForImageClassification"),
("van", "VanForImageClassification"),
("resnet", "ResNetForImageClassification"),
("regnet", "RegNetForImageClassification"),
("poolformer", "PoolFormerForImageClassification"),
("regnet", "RegNetForImageClassification"),
("resnet", "ResNetForImageClassification"),
("segformer", "SegformerForImageClassification"),
("swin", "SwinForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
]
)
@ -329,8 +329,8 @@ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
# Model for Semantic Segmentation mapping
("beit", "BeitForSemanticSegmentation"),
("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
("dpt", "DPTForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
]
)
@ -350,72 +350,71 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("yoso", "YosoForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("fnet", "FNetForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("mbart", "MBartForConditionalGeneration"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("perceiver", "PerceiverForMaskedLM"),
("qdqbert", "QDQBertForMaskedLM"),
("fnet", "FNetForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("mbart", "MBartForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("reformer", "ReformerForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("yoso", "YosoForMaskedLM"),
]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Object Detection mapping
("yolos", "YolosForObjectDetection"),
("detr", "DetrForObjectDetection"),
("yolos", "YolosForObjectDetection"),
]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("tapex", "BartForConditionalGeneration"),
("plbart", "PLBartForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("mt5", "MT5ForConditionalGeneration"),
("t5", "T5ForConditionalGeneration"),
("pegasus", "PegasusForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"),
("fsmt", "FSMTForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("mbart", "MBartForConditionalGeneration"),
("blenderbot", "BlenderbotForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"),
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
("mt5", "MT5ForConditionalGeneration"),
("pegasus", "PegasusForConditionalGeneration"),
("plbart", "PLBartForConditionalGeneration"),
("prophetnet", "ProphetNetForConditionalGeneration"),
("t5", "T5ForConditionalGeneration"),
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
]
)
@ -429,98 +428,97 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("tapex", "BartForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("perceiver", "PerceiverForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("data2vec-text", "Data2VecTextForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("data2vec-text", "Data2VecTextForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_neo", "GPTNeoForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("perceiver", "PerceiverForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
]
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("yoso", "YosoForQuestionAnswering"),
("nystromformer", "NystromformerForQuestionAnswering"),
("qdqbert", "QDQBertForQuestionAnswering"),
("fnet", "FNetForQuestionAnswering"),
("gptj", "GPTJForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("electra", "ElectraForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("data2vec-text", "Data2VecTextForQuestionAnswering"),
("deberta", "DebertaForQuestionAnswering"),
("deberta-v2", "DebertaV2ForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("electra", "ElectraForQuestionAnswering"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("fnet", "FNetForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("gptj", "GPTJForQuestionAnswering"),
("ibert", "IBertForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("nystromformer", "NystromformerForQuestionAnswering"),
("qdqbert", "QDQBertForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("splinter", "SplinterForQuestionAnswering"),
("data2vec-text", "Data2VecTextForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("yoso", "YosoForQuestionAnswering"),
]
)
@ -534,132 +532,132 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("yoso", "YosoForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("fnet", "FNetForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("flaubert", "FlaubertForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("bert", "BertForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("xlnet", "XLNetForTokenClassification"),
("albert", "AlbertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
("data2vec-text", "Data2VecTextForTokenClassification"),
("deberta", "DebertaForTokenClassification"),
("deberta-v2", "DebertaV2ForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("flaubert", "FlaubertForTokenClassification"),
("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("ibert", "IBertForTokenClassification"),
("data2vec-text", "Data2VecTextForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
("xlnet", "XLNetForTokenClassification"),
("yoso", "YosoForTokenClassification"),
]
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("yoso", "YosoForMultipleChoice"),
("nystromformer", "NystromformerForMultipleChoice"),
("qdqbert", "QDQBertForMultipleChoice"),
("fnet", "FNetForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("data2vec-text", "Data2VecTextForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("bert", "BertForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("data2vec-text", "Data2VecTextForMultipleChoice"),
("deberta-v2", "DebertaV2ForMultipleChoice"),
("distilbert", "DistilBertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("fnet", "FNetForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("deberta-v2", "DebertaV2ForMultipleChoice"),
("nystromformer", "NystromformerForMultipleChoice"),
("qdqbert", "QDQBertForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("yoso", "YosoForMultipleChoice"),
]
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("qdqbert", "QDQBertForNextSentencePrediction"),
("bert", "BertForNextSentencePrediction"),
("fnet", "FNetForNextSentencePrediction"),
("megatron-bert", "MegatronBertForNextSentencePrediction"),
("mobilebert", "MobileBertForNextSentencePrediction"),
("qdqbert", "QDQBertForNextSentencePrediction"),
]
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("unispeech-sat", "UniSpeechSatForSequenceClassification"),
("unispeech", "UniSpeechForSequenceClassification"),
("data2vec-audio", "Data2VecAudioForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
("sew", "SEWForSequenceClassification"),
("sew-d", "SEWDForSequenceClassification"),
("unispeech", "UniSpeechForSequenceClassification"),
("unispeech-sat", "UniSpeechSatForSequenceClassification"),
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("wavlm", "WavLMForSequenceClassification"),
("data2vec-audio", "Data2VecAudioForSequenceClassification"),
]
)
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
# Model for Connectionist temporal classification (CTC) mapping
("wav2vec2", "Wav2Vec2ForCTC"),
("unispeech-sat", "UniSpeechSatForCTC"),
("unispeech", "UniSpeechForCTC"),
("data2vec-audio", "Data2VecAudioForCTC"),
("hubert", "HubertForCTC"),
("sew", "SEWForCTC"),
("sew-d", "SEWDForCTC"),
("unispeech", "UniSpeechForCTC"),
("unispeech-sat", "UniSpeechSatForCTC"),
("wav2vec2", "Wav2Vec2ForCTC"),
("wavlm", "WavLMForCTC"),
("data2vec-audio", "Data2VecAudioForCTC"),
]
)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Frame Classification mapping
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
]
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
[
# Model for Audio XVector mapping
("wav2vec2", "Wav2Vec2ForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wavlm", "WavLMForXVector"),
("data2vec-audio", "Data2VecAudioForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wav2vec2", "Wav2Vec2ForXVector"),
("wavlm", "WavLMForXVector"),
]
)

View File

@ -28,31 +28,31 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("xglm", "FlaxXGLMModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
("bert", "FlaxBertModel"),
("beit", "FlaxBeitModel"),
("big_bird", "FlaxBigBirdModel"),
("bart", "FlaxBartModel"),
("beit", "FlaxBeitModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("electra", "FlaxElectraModel"),
("clip", "FlaxCLIPModel"),
("vit", "FlaxViTModel"),
("mbart", "FlaxMBartModel"),
("t5", "FlaxT5Model"),
("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
("blenderbot", "FlaxBlenderbotModel"),
("mbart", "FlaxMBartModel"),
("mt5", "FlaxMT5Model"),
("pegasus", "FlaxPegasusModel"),
("roberta", "FlaxRobertaModel"),
("roformer", "FlaxRoFormerModel"),
("t5", "FlaxT5Model"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
)
@ -60,56 +60,56 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("albert", "FlaxAlbertForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForPreTraining"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("albert", "FlaxAlbertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("distilbert", "FlaxDistilBertForMaskedLM"),
("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("bart", "FlaxBartForConditionalGeneration"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("marian", "FlaxMarianMTModel"),
("mbart", "FlaxMBartForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
]
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classification
("vit", "FlaxViTForImageClassification"),
("beit", "FlaxBeitForImageClassification"),
("vit", "FlaxViTForImageClassification"),
]
)
@ -122,75 +122,75 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("albert", "FlaxAlbertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("distilbert", "FlaxDistilBertForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("roformer", "FlaxRoFormerForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
]
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("albert", "FlaxAlbertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("roformer", "FlaxRoFormerForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
]
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("albert", "FlaxAlbertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("distilbert", "FlaxDistilBertForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("roformer", "FlaxRoFormerForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
]
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("albert", "FlaxAlbertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("distilbert", "FlaxDistilBertForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("roformer", "FlaxRoFormerForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
]
)

View File

@ -29,142 +29,142 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("speech_to_text", "TFSpeech2TextModel"),
("clip", "TFCLIPModel"),
("deberta-v2", "TFDebertaV2Model"),
("deberta", "TFDebertaModel"),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
("t5", "TFT5Model"),
("distilbert", "TFDistilBertModel"),
("albert", "TFAlbertModel"),
("bart", "TFBartModel"),
("camembert", "TFCamembertModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("longformer", "TFLongformerModel"),
("roberta", "TFRobertaModel"),
("layoutlm", "TFLayoutLMModel"),
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
("flaubert", "TFFlaubertModel"),
("xlm", "TFXLMModel"),
("ctrl", "TFCTRLModel"),
("electra", "TFElectraModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("dpr", "TFDPRQuestionEncoder"),
("mpnet", "TFMPNetModel"),
("tapas", "TFTapasModel"),
("mbart", "TFMBartModel"),
("marian", "TFMarianModel"),
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("camembert", "TFCamembertModel"),
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("ctrl", "TFCTRLModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("deberta", "TFDebertaModel"),
("deberta-v2", "TFDebertaV2Model"),
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("electra", "TFElectraModel"),
("flaubert", "TFFlaubertModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("hubert", "TFHubertModel"),
("layoutlm", "TFLayoutLMModel"),
("led", "TFLEDModel"),
("longformer", "TFLongformerModel"),
("lxmert", "TFLxmertModel"),
("marian", "TFMarianModel"),
("mbart", "TFMBartModel"),
("mobilebert", "TFMobileBertModel"),
("mpnet", "TFMPNetModel"),
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
("pegasus", "TFPegasusModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("t5", "TFT5Model"),
("tapas", "TFTapasModel"),
("transfo-xl", "TFTransfoXLModel"),
("vit", "TFViTModel"),
("vit_mae", "TFViTMAEModel"),
("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"),
("xlm", "TFXLMModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("xlnet", "TFXLNetModel"),
]
)
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("lxmert", "TFLxmertForPreTraining"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForPreTraining"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForPreTraining"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForPreTraining"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("camembert", "TFCamembertForMaskedLM"),
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForPreTraining"),
("tapas", "TFTapasForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForPreTraining"),
("gpt2", "TFGPT2LMHeadModel"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("lxmert", "TFLxmertForPreTraining"),
("mobilebert", "TFMobileBertForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("roberta", "TFRobertaForMaskedLM"),
("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("vit_mae", "TFViTMAEForPreTraining"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("marian", "TFMarianMTModel"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("camembert", "TFCamembertForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("longformer", "TFLongformerForMaskedLM"),
("marian", "TFMarianMTModel"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("rembert", "TFRemBertForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("camembert", "TFCamembertForCausalLM"),
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("camembert", "TFCamembertForCausalLM"),
("ctrl", "TFCTRLLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("rembert", "TFRemBertForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
]
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classification
("vit", "TFViTForImageClassification"),
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("vit", "TFViTForImageClassification"),
]
)
@ -177,42 +177,42 @@ TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("deberta", "TFDebertaForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("camembert", "TFCamembertForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("deberta", "TFDebertaForMaskedLM"),
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("led", "TFLEDForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
("encoder-decoder", "TFEncoderDecoderModel"),
("led", "TFLEDForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
]
)
@ -225,58 +225,58 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("deberta", "TFDebertaForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("albert", "TFAlbertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
("deberta", "TFDebertaForSequenceClassification"),
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("tapas", "TFTapasForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("gptj", "TFGPTJForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("tapas", "TFTapasForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
("deberta", "TFDebertaForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("albert", "TFAlbertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("camembert", "TFCamembertForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("deberta", "TFDebertaForQuestionAnswering"),
("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("electra", "TFElectraForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("funnel", "TFFunnelForQuestionAnswering"),
("gptj", "TFGPTJForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
]
)
@ -291,49 +291,49 @@ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("deberta", "TFDebertaForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("albert", "TFAlbertForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("deberta", "TFDebertaForTokenClassification"),
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("rembert", "TFRemBertForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
("rembert", "TFRemBertForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
]
)

View File

@ -41,17 +41,17 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("flava", "FLAVAProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("vilt", "ViltProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("wavlm", "Wav2Vec2Processor"),
]
)
@ -65,7 +65,10 @@ def processor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
try:
return getattr(module, class_name)
except AttributeError:
continue
for processor in PROCESSOR_MAPPING._extra_content.values():
if getattr(processor, "__name__", None) == class_name:

View File

@ -46,27 +46,6 @@ if TYPE_CHECKING:
else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"albert",
(
@ -74,6 +53,30 @@ else:
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("bart", ("BartTokenizer", "BartTokenizerFast")),
(
"barthez",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
("bartpho", ("BartphoTokenizer", None)),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
(
"big_bird",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("byt5", ("ByT5Tokenizer", None)),
(
"camembert",
(
@ -81,76 +84,24 @@ else:
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
("canine", ("CanineTokenizer", None)),
(
"pegasus",
"clip",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"mbart",
"cpm",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("tapex", ("TapexTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
(
"DPRQuestionEncoderTokenizer",
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("opt", ("GPT2Tokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
(
"deberta-v2",
@ -159,51 +110,39 @@ else:
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
),
),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
"dpr",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
"DPRQuestionEncoderTokenizer",
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("flaubert", ("FlaubertTokenizer", None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
("byt5", ("ByT5Tokenizer", None)),
(
"cpm",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("bartpho", ("BartphoTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
(
"barthez",
"mbart",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
@ -213,37 +152,17 @@ else:
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"rembert",
(
"RemBertTokenizer" if is_sentencepiece_available() else None,
"RemBertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"clip",
(
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
(
"perceiver",
(
"PerceiverTokenizer",
None,
),
),
(
"xglm",
(
"XGLMTokenizer" if is_sentencepiece_available() else None,
"XGLMTokenizerFast" if is_tokenizers_available() else None,
),
),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"nystromformer",
(
@ -251,7 +170,89 @@ else:
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("opt", ("GPT2Tokenizer", None)),
(
"pegasus",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"perceiver",
(
"PerceiverTokenizer",
None,
),
),
("phobert", ("PhobertTokenizer", None)),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("rag", ("RagTokenizer", None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"rembert",
(
"RemBertTokenizer" if is_sentencepiece_available() else None,
"RemBertTokenizerFast" if is_tokenizers_available() else None,
),
),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
(
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
("tapas", ("TapasTokenizer", None)),
("tapex", ("TapexTokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
(
"xglm",
(
"XGLMTokenizer" if is_sentencepiece_available() else None,
"XGLMTokenizerFast" if is_tokenizers_available() else None,
),
),
("xlm", ("XLMTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"yoso",
(
@ -259,7 +260,6 @@ else:
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
]
)
@ -277,7 +277,10 @@ def tokenizer_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
try:
return getattr(module, class_name)
except AttributeError:
continue
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
for tokenizer in tokenizers:

View File

@ -14,6 +14,8 @@
# limitations under the License.
import importlib
import json
import os
import sys
import tempfile
import unittest
@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase):
self.assertIsInstance(config, RobertaConfig)
def test_pattern_matching_fallback(self):
"""
In cases where config.json doesn't include a model_type,
perform a few safety checks on the config mapping's order.
"""
# no key string should be included in a later key string (typical failure case)
keys = list(CONFIG_MAPPING.keys())
for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
with tempfile.TemporaryDirectory() as tmp_dir:
# This model name contains bert and roberta, but roberta ends up being picked.
folder = os.path.join(tmp_dir, "fake-roberta")
os.makedirs(folder, exist_ok=True)
with open(os.path.join(folder, "config.json"), "w") as f:
f.write(json.dumps({}))
config = AutoConfig.from_pretrained(folder)
self.assertEqual(type(config), RobertaConfig)
def test_new_config_registration(self):
try:

View File

@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
# Folder scanned for files containing auto mappings (relative path, resolved from the current working directory).
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"

# re pattern that matches mapping introductions:
#    SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
# (raw string: `\s` is not a valid escape sequence in a regular string literal)
_re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# re pattern that matches identifiers in mappings: the first double-quoted string following an opening parenthesis
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
def sort_auto_mapping(fname, overwrite: bool = False):
    """
    Sort the entries of every auto mapping defined in `fname` by their identifier.

    A mapping starts on a line matched by `_re_intro_mapping` (`XXX_MAPPING = OrderedDict` or
    `XXX_MAPPING_YYY = OrderedDict`). Its entries are either one-line tuples or multi-line tuples opened by a lone
    `(`; they are reordered according to the first quoted string they contain. Lines outside of the entry lists are
    copied unchanged.

    Args:
        fname: Path to the Python file to process.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to write the sorted content back to `fname`. The file is only rewritten when sorting actually
            changed something.

    Returns:
        `None` when `overwrite=True`. Otherwise `True` if the file contains a mapping that needs sorting and `False`
        if it is already sorted.

    Raises:
        ValueError: If an entry of a mapping has no quoted identifier to sort on.
    """

    def _identifier(block):
        # Sorting key of an entry: the first double-quoted string after its opening parenthesis.
        match = _re_identifier.search(block)
        if match is None:
            raise ValueError(f"Could not find the identifier of the following mapping entry in {fname}:\n{block}")
        return match.groups()[0]

    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    new_lines = []
    line_idx = 0
    while line_idx < len(lines):
        if _re_intro_mapping.search(lines[line_idx]) is None:
            new_lines.append(lines[line_idx])
            line_idx += 1
            continue

        # Start of a new mapping! Entries are expected 8 spaces deeper than the line introducing the mapping.
        indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
        entry_prefix = " " * indent
        # Copy everything (introduction, opening bracket, leading comments) up to the first entry.
        while not lines[line_idx].startswith(entry_prefix + "("):
            new_lines.append(lines[line_idx])
            line_idx += 1

        blocks = []
        while lines[line_idx].strip() != "]":
            # Blocks either fit in one line or not
            if lines[line_idx].strip() == "(":
                # Multi-line entry: gather lines until the closing parenthesis at the entry indentation.
                start_idx = line_idx
                while not lines[line_idx].startswith(entry_prefix + ")"):
                    line_idx += 1
                blocks.append("\n".join(lines[start_idx : line_idx + 1]))
            else:
                blocks.append(lines[line_idx])
            line_idx += 1

        # Sort blocks by their identifiers. The closing "]" line is copied by the next iteration of the outer loop.
        new_lines += sorted(blocks, key=_identifier)

    new_content = "\n".join(new_lines)
    needs_sorting = new_content != content
    if not overwrite:
        return needs_sorting

    # Leave already-sorted files untouched instead of rewriting them with identical content.
    if needs_sorting:
        with open(fname, "w", encoding="utf-8") as f:
            f.write(new_content)
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Sort (or check) the auto mappings of every `.py` file located directly inside `PATH_TO_AUTO_MODULE`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to fix the files in place. When `False`, the files are only checked.

    Raises:
        ValueError: When `overwrite=False` and at least one file has auto mappings that need sorting.
    """
    # `os.listdir` returns entries in arbitrary order: sort them so that the files are always processed (and
    # reported in the error message) in the same order.
    module_files = sorted(f for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py"))
    fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in module_files]
    diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]

    if not overwrite and any(diffs):
        failures = [f for f, d in zip(fnames, diffs) if d]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
if __name__ == "__main__":
    # Without `--check_only` the files are fixed in place; with it they are only checked.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    check_only = cli.parse_args().check_only

    sort_all_auto_mappings(overwrite=not check_only)