Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 04:40:06 +06:00

Automatically sort auto mappings (#17250)

* Automatically sort auto mappings
* Better class extraction
* Some auto class magic
* Adapt test and underlying behavior
* Remove re-used config
* Quality

This commit is contained in:
parent 2f611f85e2
commit ddb1a47ec8
@@ -857,6 +857,7 @@ jobs:
- run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: flake8 examples tests src utils
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source

Makefile

@@ -48,6 +48,7 @@ quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source

@@ -55,6 +56,7 @@ quality:
extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source

# this target runs checks on all files and potentially modifies some of them
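
The two Makefile hunks above wire the new `utils/sort_auto_mappings.py` script into the existing workflow: `make quality` now runs it with `--check_only`, and `make extra_style_checks` runs it without the flag so it rewrites the auto-mapping files in place. The script itself is not shown in this diff; purely as an illustrative sketch (not the real implementation), alphabetically sorting the `("model_type", "ClassName")` entries of one mapping block could look like this, with the helper name and regex being assumptions:

```python
import re

# Hypothetical sketch: sort the entries of one auto-mapping block by model type,
# keeping any leading "# Add ... here" comment lines at the top.
_ENTRY_RE = re.compile(r'^\s*\("(?P<key>[^"]+)",')

def sort_mapping_block(lines):
    """Sort source lines like '    ("bert", "BertConfig"),' alphabetically by key."""
    comments = [line for line in lines if line.lstrip().startswith("#")]
    entries = [line for line in lines if _ENTRY_RE.match(line)]
    entries.sort(key=lambda line: _ENTRY_RE.match(line).group("key"))
    return comments + entries

if __name__ == "__main__":
    block = [
        '        # Add configs here',
        '        ("yolos", "YolosConfig"),',
        '        ("tapex", "BartConfig"),',
        '        ("dpt", "DPTConfig"),',
    ]
    print("\n".join(sort_mapping_block(block)))
```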

@@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow.
| Swin | ❌ | ❌ | ✅ | ❌ | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| TAPEX | ✅ | ✅ | ✅ | ✅ | ✅ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |

@@ -74,7 +74,6 @@ Ready-made configurations include the following architectures:
- RoBERTa
- RoFormer
- T5
- TAPEX
- ViT
- XLM-RoBERTa
- XLM-RoBERTa-XL

@@ -560,10 +560,17 @@ class _LazyAutoMapping(OrderedDict):
if key in self._extra_content:
return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
if model_type in self._model_mapping:
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)

# Maybe there was several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping:
model_name = self._model_mapping[mtype]
return self._load_attr_from_module(mtype, model_name)
raise KeyError(key)

def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
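
The rewritten `__getitem__` above no longer assumes each config class belongs to exactly one model type: when the reverse lookup lands on a model type with no entry in the model mapping, it now scans `_config_mapping` for every model type that shares the config class and tries each of them (for example, both `bart` and `tapex` point to `BartConfig` in the `CONFIG_MAPPING_NAMES` block below). A minimal standalone sketch of that fallback, using plain dictionaries instead of the real lazy-loading machinery:

```python
from collections import OrderedDict

# Simplified stand-ins for the real lazy mappings (illustration only).
config_mapping = OrderedDict([("bart", "BartConfig"), ("tapex", "BartConfig"), ("bert", "BertConfig")])
model_mapping = OrderedDict([("bart", "BartModel"), ("bert", "BertModel")])

def lookup(config_name: str) -> str:
    """Return the model class name registered for a given config class name."""
    reverse = {v: k for k, v in config_mapping.items()}  # keeps the last model type per config
    model_type = reverse[config_name]
    if model_type in model_mapping:
        return model_mapping[model_type]
    # Maybe several model types are associated with this config: try them all.
    candidates = [k for k, v in config_mapping.items() if v == config_name]
    for model_type in candidates:
        if model_type in model_mapping:
            return model_mapping[model_type]
    raise KeyError(config_name)

print(lookup("BartConfig"))  # "BartModel" -- found via the fallback, since "tapex" has no model class
```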

@@ -29,339 +29,338 @@ logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("yolos", "YolosConfig"),
("tapex", "BartConfig"),
("dpt", "DPTConfig"),
("decision_transformer", "DecisionTransformerConfig"),
("glpn", "GLPNConfig"),
("maskformer", "MaskFormerConfig"),
("decision_transformer", "DecisionTransformerConfig"),
("poolformer", "PoolFormerConfig"),
("convnext", "ConvNextConfig"),
("van", "VanConfig"),
("resnet", "ResNetConfig"),
("regnet", "RegNetConfig"),
("yoso", "YosoConfig"),
("swin", "SwinConfig"),
("vilt", "ViltConfig"),
("vit_mae", "ViTMAEConfig"),
("realm", "RealmConfig"),
("nystromformer", "NystromformerConfig"),
("xglm", "XGLMConfig"),
("imagegpt", "ImageGPTConfig"),
("qdqbert", "QDQBertConfig"),
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
("trocr", "TrOCRConfig"),
("fnet", "FNetConfig"),
("segformer", "SegformerConfig"),
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("perceiver", "PerceiverConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("plbart", "PLBartConfig"),
("beit", "BeitConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
("flava", "FlavaConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
("detr", "DetrConfig"),
("gpt_neo", "GPTNeoConfig"),
("big_bird", "BigBirdConfig"),
("speech_to_text_2", "Speech2Text2Config"),
("speech_to_text", "Speech2TextConfig"),
("vit", "ViTConfig"),
("wav2vec2", "Wav2Vec2Config"),
("m2m_100", "M2M100Config"),
("convbert", "ConvBertConfig"),
("led", "LEDConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("retribert", "RetriBertConfig"),
("ibert", "IBertConfig"),
("mt5", "MT5Config"),
("t5", "T5Config"),
("mobilebert", "MobileBertConfig"),
("distilbert", "DistilBertConfig"),
("albert", "AlbertConfig"),
("bert-generation", "BertGenerationConfig"),
("camembert", "CamembertConfig"),
("xlm-roberta-xl", "XLMRobertaXLConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("pegasus", "PegasusConfig"),
("marian", "MarianConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mpnet", "MPNetConfig"),
("bart", "BartConfig"),
("opt", "OPTConfig"),
("blenderbot", "BlenderbotConfig"),
("reformer", "ReformerConfig"),
("longformer", "LongformerConfig"),
("roberta", "RobertaConfig"),
("deberta-v2", "DebertaV2Config"),
("deberta", "DebertaConfig"),
("flaubert", "FlaubertConfig"),
("fsmt", "FSMTConfig"),
("squeezebert", "SqueezeBertConfig"),
("hubert", "HubertConfig"),
("beit", "BeitConfig"),
("bert", "BertConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("gpt2", "GPT2Config"),
("transfo-xl", "TransfoXLConfig"),
("xlnet", "XLNetConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("prophetnet", "ProphetNetConfig"),
("xlm", "XLMConfig"),
("bert-generation", "BertGenerationConfig"),
("big_bird", "BigBirdConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("blenderbot", "BlenderbotConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("camembert", "CamembertConfig"),
("canine", "CanineConfig"),
("clip", "CLIPConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
("ctrl", "CTRLConfig"),
("electra", "ElectraConfig"),
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("funnel", "FunnelConfig"),
("lxmert", "LxmertConfig"),
("dpr", "DPRConfig"),
("layoutlm", "LayoutLMConfig"),
("rag", "RagConfig"),
("tapas", "TapasConfig"),
("splinter", "SplinterConfig"),
("sew-d", "SEWDConfig"),
("sew", "SEWConfig"),
("unispeech-sat", "UniSpeechSatConfig"),
("unispeech", "UniSpeechConfig"),
("wavlm", "WavLMConfig"),
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
("deberta", "DebertaConfig"),
("deberta-v2", "DebertaV2Config"),
("decision_transformer", "DecisionTransformerConfig"),
("decision_transformer", "DecisionTransformerConfig"),
("deit", "DeiTConfig"),
("detr", "DetrConfig"),
("distilbert", "DistilBertConfig"),
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
("electra", "ElectraConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("flaubert", "FlaubertConfig"),
("flava", "FlavaConfig"),
("fnet", "FNetConfig"),
("fsmt", "FSMTConfig"),
("funnel", "FunnelConfig"),
("glpn", "GLPNConfig"),
("gpt2", "GPT2Config"),
("gpt_neo", "GPTNeoConfig"),
("gptj", "GPTJConfig"),
("hubert", "HubertConfig"),
("ibert", "IBertConfig"),
("imagegpt", "ImageGPTConfig"),
("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("led", "LEDConfig"),
("longformer", "LongformerConfig"),
("luke", "LukeConfig"),
("lxmert", "LxmertConfig"),
("m2m_100", "M2M100Config"),
("marian", "MarianConfig"),
("maskformer", "MaskFormerConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mobilebert", "MobileBertConfig"),
("mpnet", "MPNetConfig"),
("mt5", "MT5Config"),
("nystromformer", "NystromformerConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("opt", "OPTConfig"),
("pegasus", "PegasusConfig"),
("perceiver", "PerceiverConfig"),
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("prophetnet", "ProphetNetConfig"),
("qdqbert", "QDQBertConfig"),
("rag", "RagConfig"),
("realm", "RealmConfig"),
("reformer", "ReformerConfig"),
("regnet", "RegNetConfig"),
("rembert", "RemBertConfig"),
("resnet", "ResNetConfig"),
("retribert", "RetriBertConfig"),
("roberta", "RobertaConfig"),
("roformer", "RoFormerConfig"),
("segformer", "SegformerConfig"),
("sew", "SEWConfig"),
("sew-d", "SEWDConfig"),
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
("speech_to_text", "Speech2TextConfig"),
("speech_to_text_2", "Speech2Text2Config"),
("splinter", "SplinterConfig"),
("squeezebert", "SqueezeBertConfig"),
("swin", "SwinConfig"),
("t5", "T5Config"),
("tapas", "TapasConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
("unispeech", "UniSpeechConfig"),
("unispeech-sat", "UniSpeechSatConfig"),
("van", "VanConfig"),
("vilt", "ViltConfig"),
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("visual_bert", "VisualBertConfig"),
("vit", "ViTConfig"),
("vit_mae", "ViTMAEConfig"),
("wav2vec2", "Wav2Vec2Config"),
("wavlm", "WavLMConfig"),
("xglm", "XGLMConfig"),
("xlm", "XLMConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("xlm-roberta-xl", "XLMRobertaXLConfig"),
("xlnet", "XLNetConfig"),
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
]
)

CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here)
("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sew-d", "SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text_2", "SPEECH_TO_TEXT_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xglm", "XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
]
)

MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("yolos", "YOLOS"),
("tapex", "TAPEX"),
("dpt", "DPT"),
("decision_transformer", "Decision Transformer"),
("glpn", "GLPN"),
("maskformer", "MaskFormer"),
("poolformer", "PoolFormer"),
("convnext", "ConvNext"),
("van", "VAN"),
("resnet", "ResNet"),
("regnet", "RegNet"),
("yoso", "YOSO"),
("swin", "Swin"),
("vilt", "ViLT"),
("vit_mae", "ViTMAE"),
("realm", "Realm"),
("nystromformer", "Nystromformer"),
("xglm", "XGLM"),
("imagegpt", "ImageGPT"),
("qdqbert", "QDQBert"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("trocr", "TrOCR"),
("fnet", "FNet"),
("segformer", "SegFormer"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("perceiver", "Perceiver"),
("gptj", "GPT-J"),
("beit", "BEiT"),
("plbart", "PLBart"),
("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"),
("visual_bert", "VisualBert"),
("canine", "Canine"),
("roformer", "RoFormer"),
("clip", "CLIP"),
("flava", "Flava"),
("bigbird_pegasus", "BigBirdPegasus"),
("deit", "DeiT"),
("luke", "LUKE"),
("detr", "DETR"),
("gpt_neo", "GPT Neo"),
("big_bird", "BigBird"),
("speech_to_text_2", "Speech2Text2"),
("speech_to_text", "Speech2Text"),
("vit", "ViT"),
("wav2vec2", "Wav2Vec2"),
("m2m_100", "M2M100"),
("convbert", "ConvBERT"),
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),
("retribert", "RetriBERT"),
("ibert", "I-BERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),
("distilbert", "DistilBERT"),
("albert", "ALBERT"),
("bert-generation", "Bert Generation"),
("camembert", "CamemBERT"),
("xlm-roberta", "XLM-RoBERTa"),
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
("pegasus", "Pegasus"),
("blenderbot", "Blenderbot"),
("marian", "Marian"),
("mbart", "mBART"),
("megatron-bert", "MegatronBert"),
("bart", "BART"),
("opt", "OPT"),
("reformer", "Reformer"),
("longformer", "Longformer"),
("roberta", "RoBERTa"),
("flaubert", "FlauBERT"),
("fsmt", "FairSeq Machine-Translation"),
("squeezebert", "SqueezeBERT"),
("bert", "BERT"),
("openai-gpt", "OpenAI GPT"),
("gpt2", "OpenAI GPT-2"),
("transfo-xl", "Transformer-XL"),
("xlnet", "XLNet"),
("xlm", "XLM"),
("ctrl", "CTRL"),
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("funnel", "Funnel Transformer"),
("lxmert", "LXMERT"),
("deberta-v2", "DeBERTa-v2"),
("deberta", "DeBERTa"),
("layoutlm", "LayoutLM"),
("dpr", "DPR"),
("rag", "RAG"),
("xlm-prophetnet", "XLMProphetNet"),
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
("mpnet", "MPNet"),
("tapas", "TAPAS"),
("hubert", "Hubert"),
("barthez", "BARThez"),
("phobert", "PhoBERT"),
("bartpho", "BARTpho"),
("cpm", "CPM"),
("bertweet", "Bertweet"),
("beit", "BEiT"),
("bert", "BERT"),
("bert-generation", "Bert Generation"),
("bert-japanese", "BertJapanese"),
("byt5", "ByT5"),
("mbart50", "mBART-50"),
("splinter", "Splinter"),
("sew-d", "SEW-D"),
("sew", "SEW"),
("unispeech-sat", "UniSpeechSat"),
("unispeech", "UniSpeech"),
("wavlm", "WavLM"),
("bertweet", "Bertweet"),
("big_bird", "BigBird"),
("bigbird_pegasus", "BigBirdPegasus"),
("blenderbot", "Blenderbot"),
("blenderbot-small", "BlenderbotSmall"),
("bort", "BORT"),
("dialogpt", "DialoGPT"),
("xls_r", "XLS-R"),
("t5v1.1", "T5v1.1"),
("herbert", "HerBERT"),
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
("megatron_gpt2", "MegatronGPT2"),
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
("mluke", "mLUKE"),
("layoutxlm", "LayoutXLM"),
("byt5", "ByT5"),
("camembert", "CamemBERT"),
("canine", "Canine"),
("clip", "CLIP"),
("convbert", "ConvBERT"),
("convnext", "ConvNext"),
("cpm", "CPM"),
("ctrl", "CTRL"),
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
("data2vec-vision", "Data2VecVision"),
("deberta", "DeBERTa"),
("deberta-v2", "DeBERTa-v2"),
("decision_transformer", "Decision Transformer"),
("deit", "DeiT"),
("detr", "DETR"),
("dialogpt", "DialoGPT"),
("distilbert", "DistilBERT"),
("dit", "DiT"),
("dpr", "DPR"),
("dpt", "DPT"),
("electra", "ELECTRA"),
("encoder-decoder", "Encoder decoder"),
("flaubert", "FlauBERT"),
("flava", "Flava"),
("fnet", "FNet"),
("fsmt", "FairSeq Machine-Translation"),
("funnel", "Funnel Transformer"),
("glpn", "GLPN"),
("gpt2", "OpenAI GPT-2"),
("gpt_neo", "GPT Neo"),
("gptj", "GPT-J"),
("herbert", "HerBERT"),
("hubert", "Hubert"),
("ibert", "I-BERT"),
("imagegpt", "ImageGPT"),
("layoutlm", "LayoutLM"),
("layoutlmv2", "LayoutLMv2"),
("layoutxlm", "LayoutXLM"),
("led", "LED"),
("longformer", "Longformer"),
("luke", "LUKE"),
("lxmert", "LXMERT"),
("m2m_100", "M2M100"),
("marian", "Marian"),
("maskformer", "MaskFormer"),
("mbart", "mBART"),
("mbart50", "mBART-50"),
("megatron-bert", "MegatronBert"),
("megatron_gpt2", "MegatronGPT2"),
("mluke", "mLUKE"),
("mobilebert", "MobileBERT"),
("mpnet", "MPNet"),
("mt5", "mT5"),
("nystromformer", "Nystromformer"),
("openai-gpt", "OpenAI GPT"),
("opt", "OPT"),
("pegasus", "Pegasus"),
("perceiver", "Perceiver"),
("phobert", "PhoBERT"),
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("prophetnet", "ProphetNet"),
("qdqbert", "QDQBert"),
("rag", "RAG"),
("realm", "Realm"),
("reformer", "Reformer"),
("regnet", "RegNet"),
("rembert", "RemBERT"),
("resnet", "ResNet"),
("retribert", "RetriBERT"),
("roberta", "RoBERTa"),
("roformer", "RoFormer"),
("segformer", "SegFormer"),
("sew", "SEW"),
("sew-d", "SEW-D"),
("speech-encoder-decoder", "Speech Encoder decoder"),
("speech_to_text", "Speech2Text"),
("speech_to_text_2", "Speech2Text2"),
("splinter", "Splinter"),
("squeezebert", "SqueezeBERT"),
("swin", "Swin"),
("t5", "T5"),
("t5v1.1", "T5v1.1"),
("tapas", "TAPAS"),
("tapex", "TAPEX"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),
("unispeech", "UniSpeech"),
("unispeech-sat", "UniSpeechSat"),
("van", "VAN"),
("vilt", "ViLT"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("vision-encoder-decoder", "Vision Encoder decoder"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("visual_bert", "VisualBert"),
("vit", "ViT"),
("vit_mae", "ViTMAE"),
("wav2vec2", "Wav2Vec2"),
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
("wavlm", "WavLM"),
("xglm", "XGLM"),
("xlm", "XLM"),
("xlm-prophetnet", "XLMProphetNet"),
("xlm-roberta", "XLM-RoBERTa"),
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
("xlnet", "XLNet"),
("xls_r", "XLS-R"),
("xlsr_wav2vec2", "XLSR-Wav2Vec2"),
("yolos", "YOLOS"),
("yoso", "YOSO"),
]
)

@@ -703,9 +702,10 @@ class AutoConfig:
return config_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, config_class in CONFIG_MAPPING.items():
# We go from longer names to shorter names to catch roberta before bert (for instance)
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
if pattern in str(pretrained_model_name_or_path):
return config_class.from_dict(config_dict, **kwargs)
return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)

raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. "
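
The fallback above now walks the `CONFIG_MAPPING` keys sorted by length, longest first, so that a checkpoint name containing `xlm-roberta` is matched by the `xlm-roberta` pattern rather than by the shorter `roberta` or `bert` substrings. A small self-contained illustration of the same idea, with a plain dictionary standing in for the real mapping:

```python
# Toy stand-in for CONFIG_MAPPING, illustration only.
patterns = {"bert": "BertConfig", "roberta": "RobertaConfig", "xlm-roberta": "XLMRobertaConfig"}

def match_config(name_or_path: str) -> str:
    # Longer keys first, so "xlm-roberta-base" hits "xlm-roberta", not "roberta" or "bert".
    for pattern in sorted(patterns, key=len, reverse=True):
        if pattern in name_or_path:
            return patterns[pattern]
    raise ValueError(f"Unrecognized model in {name_or_path}.")

print(match_config("xlm-roberta-base"))  # XLMRobertaConfig
print(match_config("roberta-large"))     # RobertaConfig
```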

@@ -38,30 +38,30 @@ logger = logging.get_logger(__name__)
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("dpt", "DPTFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("yolos", "YolosFeatureExtractor"),
]
)

@@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)

module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
break
try:
return getattr(module, class_name)
except AttributeError:
continue

for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name:
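
Several model types in the block above share a feature extractor class (for example `van`, `resnet`, and `regnet` all reuse `ConvNextFeatureExtractor`), so `feature_extractor_class_from_name` now probes each candidate module and simply moves on when the attribute is missing instead of failing on the first miss. A rough standalone sketch of that try/except pattern, using standard-library modules so it runs on its own (it is not the actual transformers helper):

```python
import importlib

def class_from_candidate_modules(class_name, module_names):
    """Try each candidate module in turn; skip the ones that lack the attribute."""
    for module_name in module_names:
        module = importlib.import_module(module_name)
        try:
            return getattr(module, class_name)
        except AttributeError:
            continue
    return None

# Toy usage: "json" has no OrderedDict attribute, "collections" does.
print(class_from_candidate_modules("OrderedDict", ["json", "collections"]))
```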

@@ -28,257 +28,257 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("yolos", "YolosModel"),
("dpt", "DPTModel"),
("decision_transformer", "DecisionTransformerModel"),
("glpn", "GLPNModel"),
("maskformer", "MaskFormerModel"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
("poolformer", "PoolFormerModel"),
("convnext", "ConvNextModel"),
("van", "VanModel"),
("resnet", "ResNetModel"),
("regnet", "RegNetModel"),
("yoso", "YosoModel"),
("swin", "SwinModel"),
("vilt", "ViltModel"),
("vit_mae", "ViTMAEModel"),
("nystromformer", "NystromformerModel"),
("xglm", "XGLMModel"),
("imagegpt", "ImageGPTModel"),
("qdqbert", "QDQBertModel"),
("fnet", "FNetModel"),
("segformer", "SegformerModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("perceiver", "PerceiverModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("plbart", "PLBartModel"),
("beit", "BeitModel"),
("data2vec-vision", "Data2VecVisionModel"),
("rembert", "RemBertModel"),
("visual_bert", "VisualBertModel"),
("canine", "CanineModel"),
("roformer", "RoFormerModel"),
("clip", "CLIPModel"),
("flava", "FlavaModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("deit", "DeiTModel"),
("luke", "LukeModel"),
("detr", "DetrModel"),
("gpt_neo", "GPTNeoModel"),
("big_bird", "BigBirdModel"),
("speech_to_text", "Speech2TextModel"),
("vit", "ViTModel"),
("wav2vec2", "Wav2Vec2Model"),
("unispeech-sat", "UniSpeechSatModel"),
("wavlm", "WavLMModel"),
("unispeech", "UniSpeechModel"),
("hubert", "HubertModel"),
("m2m_100", "M2M100Model"),
("convbert", "ConvBertModel"),
("led", "LEDModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("retribert", "RetriBertModel"),
("mt5", "MT5Model"),
("t5", "T5Model"),
("pegasus", "PegasusModel"),
("marian", "MarianModel"),
("mbart", "MBartModel"),
("blenderbot", "BlenderbotModel"),
("distilbert", "DistilBertModel"),
("albert", "AlbertModel"),
("camembert", "CamembertModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
("xlm-roberta", "XLMRobertaModel"),
("bart", "BartModel"),
("opt", "OPTModel"),
("longformer", "LongformerModel"),
("roberta", "RobertaModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-audio", "Data2VecAudioModel"),
("layoutlm", "LayoutLMModel"),
("squeezebert", "SqueezeBertModel"),
("beit", "BeitModel"),
("bert", "BertModel"),
("openai-gpt", "OpenAIGPTModel"),
("gpt2", "GPT2Model"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("transfo-xl", "TransfoXLModel"),
("xlnet", "XLNetModel"),
("flaubert", "FlaubertModel"),
("fsmt", "FSMTModel"),
("xlm", "XLMModel"),
("ctrl", "CTRLModel"),
("electra", "ElectraModel"),
("reformer", "ReformerModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("lxmert", "LxmertModel"),
("bert-generation", "BertGenerationEncoder"),
("big_bird", "BigBirdModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("blenderbot", "BlenderbotModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("camembert", "CamembertModel"),
("canine", "CanineModel"),
("clip", "CLIPModel"),
("convbert", "ConvBertModel"),
("convnext", "ConvNextModel"),
("ctrl", "CTRLModel"),
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-vision", "Data2VecVisionModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
("deit", "DeiTModel"),
("detr", "DetrModel"),
("distilbert", "DistilBertModel"),
("dpr", "DPRQuestionEncoder"),
("xlm-prophetnet", "XLMProphetNetModel"),
("prophetnet", "ProphetNetModel"),
("mpnet", "MPNetModel"),
("tapas", "TapasModel"),
("dpt", "DPTModel"),
("electra", "ElectraModel"),
("flaubert", "FlaubertModel"),
("flava", "FlavaModel"),
("fnet", "FNetModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("glpn", "GLPNModel"),
("gpt2", "GPT2Model"),
("gpt_neo", "GPTNeoModel"),
("gptj", "GPTJModel"),
("hubert", "HubertModel"),
("ibert", "IBertModel"),
("splinter", "SplinterModel"),
("imagegpt", "ImageGPTModel"),
("layoutlm", "LayoutLMModel"),
("layoutlmv2", "LayoutLMv2Model"),
("led", "LEDModel"),
("longformer", "LongformerModel"),
("luke", "LukeModel"),
("lxmert", "LxmertModel"),
("m2m_100", "M2M100Model"),
("marian", "MarianModel"),
("maskformer", "MaskFormerModel"),
("mbart", "MBartModel"),
("megatron-bert", "MegatronBertModel"),
("mobilebert", "MobileBertModel"),
("mpnet", "MPNetModel"),
("mt5", "MT5Model"),
("nystromformer", "NystromformerModel"),
("openai-gpt", "OpenAIGPTModel"),
("opt", "OPTModel"),
("pegasus", "PegasusModel"),
("perceiver", "PerceiverModel"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
("qdqbert", "QDQBertModel"),
("reformer", "ReformerModel"),
("regnet", "RegNetModel"),
("rembert", "RemBertModel"),
("resnet", "ResNetModel"),
("retribert", "RetriBertModel"),
("roberta", "RobertaModel"),
("roformer", "RoFormerModel"),
("segformer", "SegformerModel"),
("sew", "SEWModel"),
("sew-d", "SEWDModel"),
("speech_to_text", "Speech2TextModel"),
("splinter", "SplinterModel"),
("squeezebert", "SqueezeBertModel"),
("swin", "SwinModel"),
("t5", "T5Model"),
("tapas", "TapasModel"),
("transfo-xl", "TransfoXLModel"),
("unispeech", "UniSpeechModel"),
("unispeech-sat", "UniSpeechSatModel"),
("van", "VanModel"),
("vilt", "ViltModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("visual_bert", "VisualBertModel"),
("vit", "ViTModel"),
("vit_mae", "ViTMAEModel"),
("wav2vec2", "Wav2Vec2Model"),
("wavlm", "WavLMModel"),
("xglm", "XGLMModel"),
("xlm", "XLMModel"),
("xlm-prophetnet", "XLMProphetNetModel"),
("xlm-roberta", "XLMRobertaModel"),
("xlm-roberta-xl", "XLMRobertaXLModel"),
("xlnet", "XLNetModel"),
("yolos", "YolosModel"),
("yoso", "YosoModel"),
]
)

MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("flava", "FlavaForPreTraining"),
("vit_mae", "ViTMAEForPreTraining"),
("fnet", "FNetForPreTraining"),
("visual_bert", "VisualBertForPreTraining"),
("layoutlm", "LayoutLMForMaskedLM"),
("retribert", "RetriBertModel"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForPreTraining"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("fsmt", "FSMTForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForPreTraining"),
("big_bird", "BigBirdForPreTraining"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("camembert", "CamembertForMaskedLM"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForPreTraining"),
("lxmert", "LxmertForPreTraining"),
("funnel", "FunnelForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
("unispeech-sat", "UniSpeechSatForPreTraining"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForPreTraining"),
("flaubert", "FlaubertWithLMHeadModel"),
("flava", "FlavaForPreTraining"),
("fnet", "FNetForPreTraining"),
("fsmt", "FSMTForConditionalGeneration"),
("funnel", "FunnelForPreTraining"),
("gpt2", "GPT2LMHeadModel"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("retribert", "RetriBertModel"),
("roberta", "RobertaForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("tapas", "TapasForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("unispeech", "UniSpeechForPreTraining"),
("unispeech-sat", "UniSpeechSatForPreTraining"),
("visual_bert", "VisualBertForPreTraining"),
("vit_mae", "ViTMAEForPreTraining"),
("wav2vec2", "Wav2Vec2ForPreTraining"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlnet", "XLNetLMHeadModel"),
]
)

MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("yoso", "YosoForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("plbart", "PLBartForConditionalGeneration"),
("qdqbert", "QDQBertForMaskedLM"),
("fnet", "FNetForMaskedLM"),
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("convbert", "ConvBertForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("marian", "MarianMTModel"),
("fsmt", "FSMTForConditionalGeneration"),
("bart", "BartForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("big_bird", "BigBirdForMaskedLM"),
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("ctrl", "CTRLLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("reformer", "ReformerModelWithLMHead"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
("flaubert", "FlaubertWithLMHeadModel"),
("fnet", "FNetForMaskedLM"),
("fsmt", "FSMTForConditionalGeneration"),
("funnel", "FunnelForMaskedLM"),
("gpt2", "GPT2LMHeadModel"),
("gpt_neo", "GPTNeoForCausalLM"),
("gptj", "GPTJForCausalLM"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("led", "LEDForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("megatron-bert", "MegatronBertForCausalLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("plbart", "PLBartForConditionalGeneration"),
("qdqbert", "QDQBertForMaskedLM"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("squeezebert", "SqueezeBertForMaskedLM"),
("t5", "T5ForConditionalGeneration"),
("tapas", "TapasForMaskedLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlnet", "XLNetLMHeadModel"),
("yoso", "YosoForMaskedLM"),
]
)

MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("xglm", "XGLMForCausalLM"),
("plbart", "PLBartForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
("gptj", "GPTJForCausalLM"),
("rembert", "RemBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("gpt_neo", "GPTNeoForCausalLM"),
("big_bird", "BigBirdForCausalLM"),
("camembert", "CamembertForCausalLM"),
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("roberta", "RobertaForCausalLM"),
("bert", "BertLMHeadModel"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("bert-generation", "BertGenerationDecoder"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("bart", "BartForCausalLM"),
("opt", "OPTForCausalLM"),
("mbart", "MBartForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("marian", "MarianForCausalLM"),
("bert", "BertLMHeadModel"),
("bert-generation", "BertGenerationDecoder"),
("big_bird", "BigBirdForCausalLM"),
("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
("blenderbot", "BlenderbotForCausalLM"),
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
("camembert", "CamembertForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("electra", "ElectraForCausalLM"),
("gpt2", "GPT2LMHeadModel"),
("gpt_neo", "GPTNeoForCausalLM"),
("gptj", "GPTJForCausalLM"),
("marian", "MarianForCausalLM"),
("mbart", "MBartForCausalLM"),
("megatron-bert", "MegatronBertForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
("opt", "OPTForCausalLM"),
("pegasus", "PegasusForCausalLM"),
("plbart", "PLBartForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("rembert", "RemBertForCausalLM"),
("roberta", "RobertaForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
("xglm", "XGLMForCausalLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
("xlm-roberta", "XLMRobertaForCausalLM"),
("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
("xlnet", "XLNetLMHeadModel"),
]
)

MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
[
("vit", "ViTForMaskedImageModeling"),
("deit", "DeiTForMaskedImageModeling"),
("swin", "SwinForMaskedImageModeling"),
("vit", "ViTForMaskedImageModeling"),
]
)

@@ -293,11 +293,10 @@ MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
("vit", "ViTForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("beit", "BeitForImageClassification"),
("convnext", "ConvNextForImageClassification"),
("data2vec-vision", "Data2VecVisionForImageClassification"),
("segformer", "SegformerForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("imagegpt", "ImageGPTForImageClassification"),
(
"perceiver",

@@ -307,12 +306,13 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
"PerceiverForImageClassificationConvProcessing",
),
),
("swin", "SwinForImageClassification"),
("convnext", "ConvNextForImageClassification"),
("van", "VanForImageClassification"),
("resnet", "ResNetForImageClassification"),
("regnet", "RegNetForImageClassification"),
("poolformer", "PoolFormerForImageClassification"),
("regnet", "RegNetForImageClassification"),
("resnet", "ResNetForImageClassification"),
("segformer", "SegformerForImageClassification"),
("swin", "SwinForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
]
)

@@ -329,8 +329,8 @@ MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
# Model for Semantic Segmentation mapping
("beit", "BeitForSemanticSegmentation"),
("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
("dpt", "DPTForSemanticSegmentation"),
("segformer", "SegformerForSemanticSegmentation"),
]
)

@@ -350,72 +350,71 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("yoso", "YosoForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("bert", "BertForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("camembert", "CamembertForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("fnet", "FNetForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("mbart", "MBartForConditionalGeneration"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("nystromformer", "NystromformerForMaskedLM"),
("perceiver", "PerceiverForMaskedLM"),
("qdqbert", "QDQBertForMaskedLM"),
("fnet", "FNetForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("big_bird", "BigBirdForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("convbert", "ConvBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("distilbert", "DistilBertForMaskedLM"),
("albert", "AlbertForMaskedLM"),
("bart", "BartForConditionalGeneration"),
("mbart", "MBartForConditionalGeneration"),
("camembert", "CamembertForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("bert", "BertForMaskedLM"),
("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"),
("flaubert", "FlaubertWithLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForMaskedLM"),
("reformer", "ReformerForMaskedLM"),
("funnel", "FunnelForMaskedLM"),
("mpnet", "MPNetForMaskedLM"),
("rembert", "RemBertForMaskedLM"),
("roberta", "RobertaForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("squeezebert", "SqueezeBertForMaskedLM"),
("tapas", "TapasForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("ibert", "IBertForMaskedLM"),
("wav2vec2", "Wav2Vec2ForMaskedLM"),
("xlm", "XLMWithLMHeadModel"),
("xlm-roberta", "XLMRobertaForMaskedLM"),
("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
("yoso", "YosoForMaskedLM"),
]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Object Detection mapping
|
||||
("yolos", "YolosForObjectDetection"),
|
||||
("detr", "DetrForObjectDetection"),
|
||||
("yolos", "YolosForObjectDetection"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Seq2Seq Causal LM mapping
|
||||
("tapex", "BartForConditionalGeneration"),
|
||||
("plbart", "PLBartForConditionalGeneration"),
|
||||
("bart", "BartForConditionalGeneration"),
|
||||
("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
|
||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||
("led", "LEDForConditionalGeneration"),
|
||||
("blenderbot", "BlenderbotForConditionalGeneration"),
|
||||
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
|
||||
("mt5", "MT5ForConditionalGeneration"),
|
||||
("t5", "T5ForConditionalGeneration"),
|
||||
("pegasus", "PegasusForConditionalGeneration"),
|
||||
("encoder-decoder", "EncoderDecoderModel"),
|
||||
("fsmt", "FSMTForConditionalGeneration"),
|
||||
("led", "LEDForConditionalGeneration"),
|
||||
("m2m_100", "M2M100ForConditionalGeneration"),
|
||||
("marian", "MarianMTModel"),
|
||||
("mbart", "MBartForConditionalGeneration"),
|
||||
("blenderbot", "BlenderbotForConditionalGeneration"),
|
||||
("bart", "BartForConditionalGeneration"),
|
||||
("fsmt", "FSMTForConditionalGeneration"),
|
||||
("encoder-decoder", "EncoderDecoderModel"),
|
||||
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
|
||||
("mt5", "MT5ForConditionalGeneration"),
|
||||
("pegasus", "PegasusForConditionalGeneration"),
|
||||
("plbart", "PLBartForConditionalGeneration"),
|
||||
("prophetnet", "ProphetNetForConditionalGeneration"),
|
||||
("t5", "T5ForConditionalGeneration"),
|
||||
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
|
||||
]
|
||||
)
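The Seq2Seq table above is what backs AutoModelForSeq2SeqLM; a short usage sketch with PyTorch (checkpoint chosen purely for illustration):

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # resolves to T5ForConditionalGeneration
    inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
    outputs = model.generate(**inputs, max_length=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))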

@ -429,98 +428,97 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("tapex", "BartForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("perceiver", "PerceiverForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("albert", "AlbertForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("bart", "BartForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("data2vec-text", "Data2VecTextForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("bert", "BertForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("big_bird", "BigBirdForSequenceClassification"),
("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
("camembert", "CamembertForSequenceClassification"),
("canine", "CanineForSequenceClassification"),
("convbert", "ConvBertForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("data2vec-text", "Data2VecTextForSequenceClassification"),
("deberta", "DebertaForSequenceClassification"),
("deberta-v2", "DebertaV2ForSequenceClassification"),
("distilbert", "DistilBertForSequenceClassification"),
("electra", "ElectraForSequenceClassification"),
("flaubert", "FlaubertForSequenceClassification"),
("fnet", "FNetForSequenceClassification"),
("funnel", "FunnelForSequenceClassification"),
("gpt2", "GPT2ForSequenceClassification"),
("gpt_neo", "GPTNeoForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("ctrl", "CTRLForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("mpnet", "MPNetForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"),
("openai-gpt", "OpenAIGPTForSequenceClassification"),
("perceiver", "PerceiverForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("reformer", "ReformerForSequenceClassification"),
("rembert", "RemBertForSequenceClassification"),
("roberta", "RobertaForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
("xlnet", "XLNetForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
]
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("yoso", "YosoForQuestionAnswering"),
("nystromformer", "NystromformerForQuestionAnswering"),
("qdqbert", "QDQBertForQuestionAnswering"),
("fnet", "FNetForQuestionAnswering"),
("gptj", "GPTJForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("albert", "AlbertForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("bart", "BartForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("bert", "BertForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("electra", "ElectraForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("big_bird", "BigBirdForQuestionAnswering"),
("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
("camembert", "CamembertForQuestionAnswering"),
("canine", "CanineForQuestionAnswering"),
("convbert", "ConvBertForQuestionAnswering"),
("data2vec-text", "Data2VecTextForQuestionAnswering"),
("deberta", "DebertaForQuestionAnswering"),
("deberta-v2", "DebertaV2ForQuestionAnswering"),
("distilbert", "DistilBertForQuestionAnswering"),
("electra", "ElectraForQuestionAnswering"),
("flaubert", "FlaubertForQuestionAnsweringSimple"),
("fnet", "FNetForQuestionAnswering"),
("funnel", "FunnelForQuestionAnswering"),
("gptj", "GPTJForQuestionAnswering"),
("ibert", "IBertForQuestionAnswering"),
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
("mobilebert", "MobileBertForQuestionAnswering"),
("mpnet", "MPNetForQuestionAnswering"),
("nystromformer", "NystromformerForQuestionAnswering"),
("qdqbert", "QDQBertForQuestionAnswering"),
("reformer", "ReformerForQuestionAnswering"),
("rembert", "RemBertForQuestionAnswering"),
("roberta", "RobertaForQuestionAnswering"),
("roformer", "RoFormerForQuestionAnswering"),
("splinter", "SplinterForQuestionAnswering"),
("data2vec-text", "Data2VecTextForQuestionAnswering"),
("squeezebert", "SqueezeBertForQuestionAnswering"),
("xlm", "XLMForQuestionAnsweringSimple"),
("xlm-roberta", "XLMRobertaForQuestionAnswering"),
("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
("xlnet", "XLNetForQuestionAnsweringSimple"),
("yoso", "YosoForQuestionAnswering"),
]
)
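The question-answering pipeline picks its model class through MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES above; a tiny, illustrative sketch (the checkpoint name is an example only, not part of this commit):

    from transformers import pipeline

    qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
    print(qa(question="What gets sorted?", context="The auto mappings are now kept in alphabetical order."))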

@ -534,132 +532,132 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("yoso", "YosoForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("fnet", "FNetForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("flaubert", "FlaubertForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("bert", "BertForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("xlnet", "XLNetForTokenClassification"),
("albert", "AlbertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("bert", "BertForTokenClassification"),
("big_bird", "BigBirdForTokenClassification"),
("camembert", "CamembertForTokenClassification"),
("canine", "CanineForTokenClassification"),
("convbert", "ConvBertForTokenClassification"),
("data2vec-text", "Data2VecTextForTokenClassification"),
("deberta", "DebertaForTokenClassification"),
("deberta-v2", "DebertaV2ForTokenClassification"),
("distilbert", "DistilBertForTokenClassification"),
("electra", "ElectraForTokenClassification"),
("flaubert", "FlaubertForTokenClassification"),
("fnet", "FNetForTokenClassification"),
("funnel", "FunnelForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("ibert", "IBertForTokenClassification"),
("data2vec-text", "Data2VecTextForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("roberta", "RobertaForTokenClassification"),
("roformer", "RoFormerForTokenClassification"),
("squeezebert", "SqueezeBertForTokenClassification"),
("xlm", "XLMForTokenClassification"),
("xlm-roberta", "XLMRobertaForTokenClassification"),
("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
("xlnet", "XLNetForTokenClassification"),
("yoso", "YosoForTokenClassification"),
]
)

MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("yoso", "YosoForMultipleChoice"),
("nystromformer", "NystromformerForMultipleChoice"),
("qdqbert", "QDQBertForMultipleChoice"),
("fnet", "FNetForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("data2vec-text", "Data2VecTextForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("bert", "BertForMultipleChoice"),
("big_bird", "BigBirdForMultipleChoice"),
("camembert", "CamembertForMultipleChoice"),
("canine", "CanineForMultipleChoice"),
("convbert", "ConvBertForMultipleChoice"),
("data2vec-text", "Data2VecTextForMultipleChoice"),
("deberta-v2", "DebertaV2ForMultipleChoice"),
("distilbert", "DistilBertForMultipleChoice"),
("electra", "ElectraForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("fnet", "FNetForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("albert", "AlbertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("flaubert", "FlaubertForMultipleChoice"),
("funnel", "FunnelForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("deberta-v2", "DebertaV2ForMultipleChoice"),
("nystromformer", "NystromformerForMultipleChoice"),
("qdqbert", "QDQBertForMultipleChoice"),
("rembert", "RemBertForMultipleChoice"),
("roberta", "RobertaForMultipleChoice"),
("roformer", "RoFormerForMultipleChoice"),
("squeezebert", "SqueezeBertForMultipleChoice"),
("xlm", "XLMForMultipleChoice"),
("xlm-roberta", "XLMRobertaForMultipleChoice"),
("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
("xlnet", "XLNetForMultipleChoice"),
("yoso", "YosoForMultipleChoice"),
]
)

MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("qdqbert", "QDQBertForNextSentencePrediction"),
("bert", "BertForNextSentencePrediction"),
("fnet", "FNetForNextSentencePrediction"),
("megatron-bert", "MegatronBertForNextSentencePrediction"),
("mobilebert", "MobileBertForNextSentencePrediction"),
("qdqbert", "QDQBertForNextSentencePrediction"),
]
)

MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("unispeech-sat", "UniSpeechSatForSequenceClassification"),
("unispeech", "UniSpeechForSequenceClassification"),
("data2vec-audio", "Data2VecAudioForSequenceClassification"),
("hubert", "HubertForSequenceClassification"),
("sew", "SEWForSequenceClassification"),
("sew-d", "SEWDForSequenceClassification"),
("unispeech", "UniSpeechForSequenceClassification"),
("unispeech-sat", "UniSpeechSatForSequenceClassification"),
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
("wavlm", "WavLMForSequenceClassification"),
("data2vec-audio", "Data2VecAudioForSequenceClassification"),
]
)

MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
[
# Model for Connectionist temporal classification (CTC) mapping
("wav2vec2", "Wav2Vec2ForCTC"),
("unispeech-sat", "UniSpeechSatForCTC"),
("unispeech", "UniSpeechForCTC"),
("data2vec-audio", "Data2VecAudioForCTC"),
("hubert", "HubertForCTC"),
("sew", "SEWForCTC"),
("sew-d", "SEWDForCTC"),
("unispeech", "UniSpeechForCTC"),
("unispeech-sat", "UniSpeechSatForCTC"),
("wav2vec2", "Wav2Vec2ForCTC"),
("wavlm", "WavLMForCTC"),
("data2vec-audio", "Data2VecAudioForCTC"),
]
)
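The CTC table feeds AutoModelForCTC in the same way; an illustrative sketch (the checkpoint name is an example, not part of the commit):

    from transformers import AutoModelForCTC, AutoProcessor

    processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
    model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")  # resolves to Wav2Vec2ForCTC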

MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
("wavlm", "WavLMForAudioFrameClassification"),
]
)

MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
[
# Model for Audio Classification mapping
("wav2vec2", "Wav2Vec2ForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wavlm", "WavLMForXVector"),
("data2vec-audio", "Data2VecAudioForXVector"),
("unispeech-sat", "UniSpeechSatForXVector"),
("wav2vec2", "Wav2Vec2ForXVector"),
("wavlm", "WavLMForXVector"),
]
)
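All of the blocks above are now kept alphabetically ordered by model type. The following is a rough, hypothetical sketch of how such an ordering could be checked or restored for single-line entries; it is not the actual implementation used by the repository:

    import re

    # Matches the model type at the start of a single-line entry: ("model_type", ...)
    _ENTRY_RE = re.compile(r'^\s*\("([^"]+)"')

    def sort_mapping_block(lines):
        """Return the entry lines of one OrderedDict block sorted by model type."""
        comments = [line for line in lines if line.lstrip().startswith("#")]
        entries = [line for line in lines if _ENTRY_RE.match(line)]
        entries.sort(key=lambda line: _ENTRY_RE.match(line).group(1))
        return comments + entries  # comments stay on top, entries follow in sorted order

Multi-line entries (such as the tokenizer tuples later in this diff) would need extra handling, which this sketch deliberately leaves out.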

@ -28,31 +28,31 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("xglm", "FlaxXGLMModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
("bert", "FlaxBertModel"),
("beit", "FlaxBeitModel"),
("big_bird", "FlaxBigBirdModel"),
("bart", "FlaxBartModel"),
("beit", "FlaxBeitModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("electra", "FlaxElectraModel"),
("clip", "FlaxCLIPModel"),
("vit", "FlaxViTModel"),
("mbart", "FlaxMBartModel"),
("t5", "FlaxT5Model"),
("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
("blenderbot", "FlaxBlenderbotModel"),
("mbart", "FlaxMBartModel"),
("mt5", "FlaxMT5Model"),
("pegasus", "FlaxPegasusModel"),
("roberta", "FlaxRobertaModel"),
("roformer", "FlaxRoFormerModel"),
("t5", "FlaxT5Model"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
)

@ -60,56 +60,56 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("albert", "FlaxAlbertForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForPreTraining"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)

FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("albert", "FlaxAlbertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("distilbert", "FlaxDistilBertForMaskedLM"),
("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)

FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("bart", "FlaxBartForConditionalGeneration"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("marian", "FlaxMarianMTModel"),
("mbart", "FlaxMBartForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
]
)

FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classsification
("vit", "FlaxViTForImageClassification"),
("beit", "FlaxBeitForImageClassification"),
("vit", "FlaxViTForImageClassification"),
]
)

@ -122,75 +122,75 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
]
)

FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("albert", "FlaxAlbertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("distilbert", "FlaxDistilBertForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("roformer", "FlaxRoFormerForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
]
)

FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("albert", "FlaxAlbertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("roformer", "FlaxRoFormerForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
]
)

FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("albert", "FlaxAlbertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("distilbert", "FlaxDistilBertForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("roformer", "FlaxRoFormerForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
]
)

FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("albert", "FlaxAlbertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("distilbert", "FlaxDistilBertForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("roformer", "FlaxRoFormerForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
]
)
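The Flax tables in modeling_flax_auto.py are consumed the same way by the Flax auto classes; a brief sketch (the checkpoint is only an example and assumes Flax weights are published for it):

    from transformers import FlaxAutoModelForSequenceClassification

    # Resolves to FlaxBertForSequenceClassification via the ("bert", ...) entry above.
    model = FlaxAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)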

@ -29,142 +29,142 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("speech_to_text", "TFSpeech2TextModel"),
("clip", "TFCLIPModel"),
("deberta-v2", "TFDebertaV2Model"),
("deberta", "TFDebertaModel"),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
("t5", "TFT5Model"),
("distilbert", "TFDistilBertModel"),
("albert", "TFAlbertModel"),
("bart", "TFBartModel"),
("camembert", "TFCamembertModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("longformer", "TFLongformerModel"),
("roberta", "TFRobertaModel"),
("layoutlm", "TFLayoutLMModel"),
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
("flaubert", "TFFlaubertModel"),
("xlm", "TFXLMModel"),
("ctrl", "TFCTRLModel"),
("electra", "TFElectraModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("dpr", "TFDPRQuestionEncoder"),
("mpnet", "TFMPNetModel"),
("tapas", "TFTapasModel"),
("mbart", "TFMBartModel"),
("marian", "TFMarianModel"),
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("camembert", "TFCamembertModel"),
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("ctrl", "TFCTRLModel"),
("data2vec-vision", "TFData2VecVisionModel"),
("deberta", "TFDebertaModel"),
("deberta-v2", "TFDebertaV2Model"),
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("electra", "TFElectraModel"),
("flaubert", "TFFlaubertModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("gpt2", "TFGPT2Model"),
("gptj", "TFGPTJModel"),
("hubert", "TFHubertModel"),
("layoutlm", "TFLayoutLMModel"),
("led", "TFLEDModel"),
("longformer", "TFLongformerModel"),
("lxmert", "TFLxmertModel"),
("marian", "TFMarianModel"),
("mbart", "TFMBartModel"),
("mobilebert", "TFMobileBertModel"),
("mpnet", "TFMPNetModel"),
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
("pegasus", "TFPegasusModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("t5", "TFT5Model"),
("tapas", "TFTapasModel"),
("transfo-xl", "TFTransfoXLModel"),
("vit", "TFViTModel"),
("vit_mae", "TFViTMAEModel"),
("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"),
("xlm", "TFXLMModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("xlnet", "TFXLNetModel"),
]
)

TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("lxmert", "TFLxmertForPreTraining"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForPreTraining"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForPreTraining"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForPreTraining"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("camembert", "TFCamembertForMaskedLM"),
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForPreTraining"),
("tapas", "TFTapasForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForPreTraining"),
("gpt2", "TFGPT2LMHeadModel"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("lxmert", "TFLxmertForPreTraining"),
("mobilebert", "TFMobileBertForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("roberta", "TFRobertaForMaskedLM"),
("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("vit_mae", "TFViTMAEForPreTraining"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)

TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("marian", "TFMarianMTModel"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("camembert", "TFCamembertForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("ctrl", "TFCTRLLMHeadModel"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("longformer", "TFLongformerForMaskedLM"),
("marian", "TFMarianMTModel"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("rembert", "TFRemBertForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("tapas", "TFTapasForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("xlnet", "TFXLNetLMHeadModel"),
]
)

TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("camembert", "TFCamembertForCausalLM"),
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("camembert", "TFCamembertForCausalLM"),
("ctrl", "TFCTRLLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("rembert", "TFRemBertForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
]
)

TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classsification
("vit", "TFViTForImageClassification"),
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("vit", "TFViTForImageClassification"),
]
)

@ -177,42 +177,42 @@ TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("deberta", "TFDebertaForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("camembert", "TFCamembertForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("deberta", "TFDebertaForMaskedLM"),
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("electra", "TFElectraForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("funnel", "TFFunnelForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("tapas", "TFTapasForMaskedLM"),
("xlm", "TFXLMWithLMHeadModel"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
]
)

TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("led", "TFLEDForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
("encoder-decoder", "TFEncoderDecoderModel"),
("led", "TFLEDForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
]
)

@ -225,58 +225,58 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("deberta", "TFDebertaForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("albert", "TFAlbertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
("deberta", "TFDebertaForSequenceClassification"),
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("tapas", "TFTapasForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("gptj", "TFGPTJForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("tapas", "TFTapasForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
]
)

TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
("deberta", "TFDebertaForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("albert", "TFAlbertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("camembert", "TFCamembertForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("deberta", "TFDebertaForQuestionAnswering"),
("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("electra", "TFElectraForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("funnel", "TFFunnelForQuestionAnswering"),
("gptj", "TFGPTJForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
]
)

@ -291,49 +291,49 @@ TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("deberta", "TFDebertaForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("albert", "TFAlbertForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("deberta", "TFDebertaForTokenClassification"),
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
]
)

TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("rembert", "TFRemBertForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
("rembert", "TFRemBertForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
]
)
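Likewise for the TensorFlow tables in modeling_tf_auto.py; note that a value can also be a tuple of candidate class names, as in the "funnel" entry of TF_MODEL_MAPPING_NAMES above. A short usage sketch (the checkpoint is chosen purely for illustration):

    from transformers import TFAutoModelForSequenceClassification

    # Resolves to TFDistilBertForSequenceClassification via the ("distilbert", ...) entry above.
    model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")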

@ -41,17 +41,17 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("flava", "FLAVAProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("vilt", "ViltProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("wavlm", "Wav2Vec2Processor"),
]
)

@ -65,7 +65,10 @@ def processor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)

module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
try:
return getattr(module, class_name)
except AttributeError:
continue

for processor in PROCESSOR_MAPPING._extra_content.values():
if getattr(processor, "__name__", None) == class_name:
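The hunk above makes processor_class_from_name fall back to PROCESSOR_MAPPING._extra_content, which holds processors registered at runtime rather than listed in PROCESSOR_MAPPING_NAMES. A hedged sketch of how an entry can end up there (MyConfig and MyProcessor are hypothetical placeholders, and the AutoProcessor.register API is assumed to be available in this version):

    from transformers import AutoProcessor, PretrainedConfig, ProcessorMixin

    class MyConfig(PretrainedConfig):  # hypothetical custom config
        model_type = "my-model"

    class MyProcessor(ProcessorMixin):  # hypothetical custom processor
        pass

    AutoProcessor.register(MyConfig, MyProcessor)
    # processor_class_from_name("MyProcessor") can now resolve the class through _extra_content.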
|
||||
|
@ -46,27 +46,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"t5",
|
||||
(
|
||||
"T5Tokenizer" if is_sentencepiece_available() else None,
|
||||
"T5TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"mt5",
|
||||
(
|
||||
"MT5Tokenizer" if is_sentencepiece_available() else None,
|
||||
"MT5TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"albert",
|
||||
(
|
||||
@ -74,6 +53,30 @@ else:
|
||||
"AlbertTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("bart", ("BartTokenizer", "BartTokenizerFast")),
|
||||
(
|
||||
"barthez",
|
||||
(
|
||||
"BarthezTokenizer" if is_sentencepiece_available() else None,
|
||||
"BarthezTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("bartpho", ("BartphoTokenizer", None)),
|
||||
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("bert-japanese", ("BertJapaneseTokenizer", None)),
|
||||
("bertweet", ("BertweetTokenizer", None)),
|
||||
(
|
||||
"big_bird",
|
||||
(
|
||||
"BigBirdTokenizer" if is_sentencepiece_available() else None,
|
||||
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
|
||||
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
|
||||
("byt5", ("ByT5Tokenizer", None)),
|
||||
(
|
||||
"camembert",
|
||||
(
|
||||
@@ -81,76 +84,24 @@ else:
                "CamembertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("canine", ("CanineTokenizer", None)),
        (
            "pegasus",
            "clip",
            (
                "PegasusTokenizer" if is_sentencepiece_available() else None,
                "PegasusTokenizerFast" if is_tokenizers_available() else None,
                "CLIPTokenizer",
                "CLIPTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
        (
            "mbart",
            "cpm",
            (
                "MBartTokenizer" if is_sentencepiece_available() else None,
                "MBartTokenizerFast" if is_tokenizers_available() else None,
                "CpmTokenizer" if is_sentencepiece_available() else None,
                "CpmTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "xlm-roberta",
            (
                "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
                "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
        ("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
        ("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
        ("tapex", ("TapexTokenizer", None)),
        ("bart", ("BartTokenizer", "BartTokenizerFast")),
        ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        (
            "reformer",
            (
                "ReformerTokenizer" if is_sentencepiece_available() else None,
                "ReformerTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
        ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
        (
            "dpr",
            (
                "DPRQuestionEncoderTokenizer",
                "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "squeezebert",
            ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
        ),
        ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("opt", ("GPT2Tokenizer", None)),
        ("transfo-xl", ("TransfoXLTokenizer", None)),
        (
            "xlnet",
            (
                "XLNetTokenizer" if is_sentencepiece_available() else None,
                "XLNetTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("flaubert", ("FlaubertTokenizer", None)),
        ("xlm", ("XLMTokenizer", None)),
        ("ctrl", ("CTRLTokenizer", None)),
        ("fsmt", ("FSMTTokenizer", None)),
        ("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
        ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
        (
            "deberta-v2",
@@ -159,51 +110,39 @@ else:
                "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("rag", ("RagTokenizer", None)),
        ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
        ("prophetnet", ("ProphetNetTokenizer", None)),
        ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
        ("tapas", ("TapasTokenizer", None)),
        ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
        ("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
        ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
        (
            "big_bird",
            "dpr",
            (
                "BigBirdTokenizer" if is_sentencepiece_available() else None,
                "BigBirdTokenizerFast" if is_tokenizers_available() else None,
                "DPRQuestionEncoderTokenizer",
                "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
        ("hubert", ("Wav2Vec2CTCTokenizer", None)),
        ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
        ("flaubert", ("FlaubertTokenizer", None)),
        ("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
        ("fsmt", ("FSMTTokenizer", None)),
        ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
        ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("luke", ("LukeTokenizer", None)),
        ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
        ("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
        ("canine", ("CanineTokenizer", None)),
        ("bertweet", ("BertweetTokenizer", None)),
        ("bert-japanese", ("BertJapaneseTokenizer", None)),
        ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
        ("byt5", ("ByT5Tokenizer", None)),
        (
            "cpm",
            (
                "CpmTokenizer" if is_sentencepiece_available() else None,
                "CpmTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
        ("phobert", ("PhobertTokenizer", None)),
        ("bartpho", ("BartphoTokenizer", None)),
        ("hubert", ("Wav2Vec2CTCTokenizer", None)),
        ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
        ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
        ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
        ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
        ("luke", ("LukeTokenizer", None)),
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
        ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
        (
            "barthez",
            "mbart",
            (
                "BarthezTokenizer" if is_sentencepiece_available() else None,
                "BarthezTokenizerFast" if is_tokenizers_available() else None,
                "MBartTokenizer" if is_sentencepiece_available() else None,
                "MBartTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
@@ -213,37 +152,17 @@ else:
                "MBart50TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "rembert",
            (
                "RemBertTokenizer" if is_sentencepiece_available() else None,
                "RemBertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "clip",
            (
                "CLIPTokenizer",
                "CLIPTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
        (
            "perceiver",
            (
                "PerceiverTokenizer",
                None,
            ),
        ),
        (
            "xglm",
            (
                "XGLMTokenizer" if is_sentencepiece_available() else None,
                "XGLMTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
        ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
        ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
        (
            "mt5",
            (
                "MT5Tokenizer" if is_sentencepiece_available() else None,
                "MT5TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "nystromformer",
            (
@@ -251,7 +170,89 @@ else:
                "AlbertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
        ("opt", ("GPT2Tokenizer", None)),
        (
            "pegasus",
            (
                "PegasusTokenizer" if is_sentencepiece_available() else None,
                "PegasusTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "perceiver",
            (
                "PerceiverTokenizer",
                None,
            ),
        ),
        ("phobert", ("PhobertTokenizer", None)),
        ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
        ("prophetnet", ("ProphetNetTokenizer", None)),
        ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("rag", ("RagTokenizer", None)),
        ("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
        (
            "reformer",
            (
                "ReformerTokenizer" if is_sentencepiece_available() else None,
                "ReformerTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "rembert",
            (
                "RemBertTokenizer" if is_sentencepiece_available() else None,
                "RemBertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
        ("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
        (
            "squeezebert",
            ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
        ),
        (
            "t5",
            (
                "T5Tokenizer" if is_sentencepiece_available() else None,
                "T5TokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("tapas", ("TapasTokenizer", None)),
        ("tapex", ("TapexTokenizer", None)),
        ("transfo-xl", ("TransfoXLTokenizer", None)),
        ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
        ("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
        ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
        (
            "xglm",
            (
                "XGLMTokenizer" if is_sentencepiece_available() else None,
                "XGLMTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("xlm", ("XLMTokenizer", None)),
        ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
        (
            "xlm-roberta",
            (
                "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
                "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        (
            "xlnet",
            (
                "XLNetTokenizer" if is_sentencepiece_available() else None,
                "XLNetTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        (
            "yoso",
            (
@@ -259,7 +260,6 @@ else:
                "AlbertTokenizerFast" if is_tokenizers_available() else None,
            ),
        ),
        ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
    ]
)
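Note (illustration, not part of the diff): each value in TOKENIZER_MAPPING_NAMES above is a pair of class names, (slow tokenizer, fast tokenizer), with a slot left as None when the sentencepiece or tokenizers backend is missing. The sketch below shows how such a pair could be reduced to a single class name; resolve_tokenizer_name is a hypothetical helper, not the library's actual _LazyAutoMapping resolution.

# Sketch only: prefer the fast tokenizer class name when one is registered,
# otherwise fall back to the slow one.
from typing import Optional, Tuple


def resolve_tokenizer_name(entry: Tuple[Optional[str], Optional[str]], use_fast: bool = True) -> str:
    slow, fast = entry
    if use_fast and fast is not None:
        return fast
    if slow is not None:
        return slow
    raise ValueError("Neither a slow nor a fast tokenizer class is available for this entry.")


print(resolve_tokenizer_name(("BertTokenizer", "BertTokenizerFast")))  # BertTokenizerFast
print(resolve_tokenizer_name(("BertweetTokenizer", None)))  # BertweetTokenizer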
@@ -277,7 +277,10 @@ def tokenizer_class_from_name(class_name: str):
            module_name = model_type_to_module_name(module_name)

            module = importlib.import_module(f".{module_name}", "transformers.models")
            return getattr(module, class_name)
            try:
                return getattr(module, class_name)
            except AttributeError:
                continue

    for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
        for tokenizer in tokenizers:
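Note (illustration, not part of the diff): the hunk above changes tokenizer_class_from_name so that a module which does not define the requested class no longer raises; the AttributeError is swallowed, the scan continues, and tokenizers registered in TOKENIZER_MAPPING._extra_content are checked last. A standalone sketch of that lookup pattern, using made-up inputs and standard-library modules as stand-ins:

# Sketch only: look a class up in a list of candidate modules, then fall back to a
# runtime registry, mirroring the "try, continue, then check extras" pattern above.
import importlib
from typing import Dict, Iterable, Optional


def find_class(class_name: str, module_names: Iterable[str], extras: Dict[str, type]) -> Optional[type]:
    for module_name in module_names:
        module = importlib.import_module(module_name)
        try:
            return getattr(module, class_name)
        except AttributeError:
            continue  # this module does not define the class, keep scanning
    # Last resort: classes registered at runtime (analogous to _extra_content).
    return extras.get(class_name)


print(find_class("OrderedDict", ["json", "collections"], extras={}))  # <class 'collections.OrderedDict'>
print(find_class("MyTokenizer", ["json"], extras={"MyTokenizer": dict}))  # <class 'dict'>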
@@ -14,6 +14,8 @@
# limitations under the License.

import importlib
import json
import os
import sys
import tempfile
import unittest
@@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase):
        self.assertIsInstance(config, RobertaConfig)

    def test_pattern_matching_fallback(self):
        """
        In cases where config.json doesn't include a model_type,
        perform a few safety checks on the config mapping's order.
        """
        # no key string should be included in a later key string (typical failure case)
        keys = list(CONFIG_MAPPING.keys())
        for i, key in enumerate(keys):
            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
        with tempfile.TemporaryDirectory() as tmp_dir:
            # This model name contains bert and roberta, but roberta ends up being picked.
            folder = os.path.join(tmp_dir, "fake-roberta")
            os.makedirs(folder, exist_ok=True)
            with open(os.path.join(folder, "config.json"), "w") as f:
                f.write(json.dumps({}))
            config = AutoConfig.from_pretrained(folder)
            self.assertEqual(type(config), RobertaConfig)

    def test_new_config_registration(self):
        try:
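Note (illustration, not part of the diff): test_pattern_matching_fallback covers AutoConfig's last-resort path for checkpoints whose config.json carries no model_type, where the type has to be guessed from the name. Because "bert" is a substring of "roberta", a naive first-match scan over alphabetically sorted keys would pick the wrong type for "fake-roberta". The sketch below illustrates one way to make such a guess robust by preferring the longest matching key; it is an assumption about the general idea, not the exact code this commit changes.

# Sketch only: guess a model type from a repo or folder name by preferring the
# longest known key that occurs in it, so "fake-roberta" maps to "roberta", not "bert".
from typing import Iterable, Optional


def guess_model_type(name_or_path: str, known_types: Iterable[str]) -> Optional[str]:
    matches = [t for t in known_types if t in name_or_path]
    return max(matches, key=len) if matches else None


print(guess_model_type("fake-roberta", ["bert", "gpt2", "roberta"]))  # roberta
print(guess_model_type("my-bert-base", ["bert", "gpt2", "roberta"]))  # bert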
89
utils/sort_auto_mappings.py
Normal file
@@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re


PATH_TO_AUTO_MODULE = "src/transformers/models/auto"


# re pattern that matches mapping introductions:
# SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
_re_intro_mapping = re.compile("[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# re pattern that matches identifiers in mappings
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')


def sort_auto_mapping(fname, overwrite: bool = False):
    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    new_lines = []
    line_idx = 0
    while line_idx < len(lines):
        if _re_intro_mapping.search(lines[line_idx]) is not None:
            indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
            # Start of a new mapping!
            while not lines[line_idx].startswith(" " * indent + "("):
                new_lines.append(lines[line_idx])
                line_idx += 1

            blocks = []
            while lines[line_idx].strip() != "]":
                # Blocks either fit in one line or not
                if lines[line_idx].strip() == "(":
                    start_idx = line_idx
                    while not lines[line_idx].startswith(" " * indent + ")"):
                        line_idx += 1
                    blocks.append("\n".join(lines[start_idx : line_idx + 1]))
                else:
                    blocks.append(lines[line_idx])
                line_idx += 1

            # Sort blocks by their identifiers
            blocks = sorted(blocks, key=lambda x: _re_identifier.search(x).groups()[0])
            new_lines += blocks
        else:
            new_lines.append(lines[line_idx])
            line_idx += 1

    if overwrite:
        with open(fname, "w", encoding="utf-8") as f:
            f.write("\n".join(new_lines))
    elif "\n".join(new_lines) != content:
        return True


def sort_all_auto_mappings(overwrite: bool = False):
    fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")]
    diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]

    if not overwrite and any(diffs):
        failures = [f for f, d in zip(fnames, diffs) if d]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    args = parser.parse_args()

    sort_all_auto_mappings(not args.check_only)
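Note (illustration, not part of the diff): this is the script wired into the quality checks with --check_only and into the style target without it. The sketch below exercises sort_auto_mapping directly on a throwaway file; the toy mapping and the assumption that the repository root is on sys.path are mine, not part of the commit.

# Sketch only: run sort_auto_mapping on a temporary file whose entries are
# deliberately out of alphabetical order, first checking, then fixing in place.
import os
import tempfile

from utils.sort_auto_mappings import sort_auto_mapping  # assumes the repo root is on sys.path

TOY_MAPPING = """DUMMY_MAPPING_NAMES = OrderedDict(
    [
        # Add entries here
        ("zebra", "ZebraTokenizer"),
        ("albert", "AlbertTokenizer"),
    ]
)
"""

with tempfile.TemporaryDirectory() as tmp_dir:
    fname = os.path.join(tmp_dir, "toy_auto.py")
    with open(fname, "w", encoding="utf-8") as f:
        f.write(TOY_MAPPING)
    print(sort_auto_mapping(fname, overwrite=False))  # True: the entries need sorting
    sort_auto_mapping(fname, overwrite=True)  # rewrites the file with "albert" before "zebra"
    print(sort_auto_mapping(fname, overwrite=False))  # None: nothing left to fix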