Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
[WIP] Disentangle auto modules from other modeling files (#13023)
* Initial work
* All auto models
* All tf auto models
* All flax auto models
* Tokenizers
* Add feature extractors
* Fix typos
* Fix other typo
* Use the right config
* Remove old mapping names and update logic in AutoTokenizer
* Update check_table
* Fix copies and check_repo script
* Fix last test
* Add back name
* clean up
* Update template
* Update template
* Forgot a )
* Use alternative to fixup
* Fix TF model template
* Address review comments
* Address review comments
* Style
This commit is contained in:
parent 2e4082364e
commit 9870093f7b
.github/workflows/model-templates.yml (vendored, 2 lines changed)
@@ -59,7 +59,7 @@ jobs:
       - name: Run style changes
         run: |
           git fetch origin master:master
-          make fixup
+          make style && make quality

       - name: Failure short reports
         if: ${{ always() }}
Makefile (1 line changed)
@@ -30,7 +30,6 @@ deps_table_check_updated:

 # autogenerating code

 autogenerate_code: deps_table_update
-	python utils/class_mapping_update.py

 # Check that source code meets quality standards
src/transformers/__init__.py
@@ -213,6 +213,7 @@ _import_structure = {
     "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
     "models.marian": ["MarianConfig"],
     "models.mbart": ["MBartConfig"],
+    "models.mbart50": [],
     "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
     "models.mmbt": ["MMBTConfig"],
     "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
@@ -315,7 +316,7 @@ if is_sentencepiece_available():
     _import_structure["models.m2m_100"].append("M2M100Tokenizer")
     _import_structure["models.marian"].append("MarianTokenizer")
     _import_structure["models.mbart"].append("MBartTokenizer")
-    _import_structure["models.mbart"].append("MBart50Tokenizer")
+    _import_structure["models.mbart50"].append("MBart50Tokenizer")
     _import_structure["models.mt5"].append("MT5Tokenizer")
     _import_structure["models.pegasus"].append("PegasusTokenizer")
     _import_structure["models.reformer"].append("ReformerTokenizer")
@@ -358,7 +359,7 @@ if is_tokenizers_available():
     _import_structure["models.longformer"].append("LongformerTokenizerFast")
     _import_structure["models.lxmert"].append("LxmertTokenizerFast")
     _import_structure["models.mbart"].append("MBartTokenizerFast")
-    _import_structure["models.mbart"].append("MBart50TokenizerFast")
+    _import_structure["models.mbart50"].append("MBart50TokenizerFast")
     _import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
     _import_structure["models.mpnet"].append("MPNetTokenizerFast")
     _import_structure["models.mt5"].append("MT5TokenizerFast")
@@ -2021,7 +2022,8 @@ if TYPE_CHECKING:
         from .models.led import LEDTokenizerFast
         from .models.longformer import LongformerTokenizerFast
         from .models.lxmert import LxmertTokenizerFast
-        from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
+        from .models.mbart import MBartTokenizerFast
+        from .models.mbart50 import MBart50TokenizerFast
         from .models.mobilebert import MobileBertTokenizerFast
         from .models.mpnet import MPNetTokenizerFast
         from .models.mt5 import MT5TokenizerFast
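The four hunks above carry one move: MBart50Tokenizer and MBart50TokenizerFast leave models.mbart for a new models.mbart50 module. A quick sketch (not part of the commit) of what that means for imports; the top-level export is unchanged because _import_structure still re-exports the class:

# Illustrative check; needs sentencepiece installed for the slow tokenizer.
from transformers import MBart50Tokenizer                 # public path, unchanged
from transformers.models.mbart50 import MBart50Tokenizer  # new submodule home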
@@ -41,9 +41,7 @@ from .file_utils import (
     is_tokenizers_available,
     is_torch_available,
 )
-from .training_args import ParallelMode
-from .utils import logging
-from .utils.modeling_auto_mapping import (
+from .models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
@@ -54,6 +52,8 @@ from .utils.modeling_auto_mapping import (
     MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
 )
+from .training_args import ParallelMode
+from .utils import logging


 TASK_MAPPING = {
src/transformers/models/__init__.py
@@ -37,6 +37,7 @@ from . import (
     cpm,
     ctrl,
     deberta,
     deberta_v2,
     deit,
     detr,
     dialogpt,
@@ -50,6 +51,8 @@ from . import (
     gpt2,
     gpt_neo,
     herbert,
     hubert,
     ibert,
     layoutlm,
     led,
     longformer,
@@ -58,6 +61,7 @@ from . import (
     m2m_100,
     marian,
     mbart,
+    mbart50,
     megatron_bert,
     mmbt,
     mobilebert,
@@ -82,6 +86,7 @@ from . import (
     vit,
     wav2vec2,
     xlm,
     xlm_prophetnet,
     xlm_roberta,
     xlnet,
 )
src/transformers/models/auto/auto_factory.py
@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Factory function to build auto-model classes."""
+import importlib
+from collections import OrderedDict

 from ...configuration_utils import PretrainedConfig
 from ...file_utils import copy_func
 from ...utils import logging
-from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
+from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


 logger = logging.get_logger(__name__)
@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
     from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
     from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
     from_config.__doc__ = from_config_docstring
-    from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
+    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
     cls.from_config = classmethod(from_config)

     if name.startswith("TF"):
@@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
     shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
     from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
     from_pretrained.__doc__ = from_pretrained_docstring
-    from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
+    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
     cls.from_pretrained = classmethod(from_pretrained)
     return cls
@@ -445,3 +447,79 @@ def get_values(model_mapping):
             result.append(model)

     return result
+
+
+def getattribute_from_module(module, attr):
+    if attr is None:
+        return None
+    if isinstance(attr, tuple):
+        return tuple(getattribute_from_module(module, a) for a in attr)
+    if hasattr(module, attr):
+        return getattr(module, attr)
+    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
+    # object at the top level.
+    transformers_module = importlib.import_module("transformers")
+    return getattribute_from_module(transformers_module, attr)
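A minimal sketch (not part of the commit) of how getattribute_from_module resolves the string names stored in the new *_MAPPING_NAMES dicts; the mbart module below is only an example, and fetching a class assumes its backend (e.g. torch for modeling classes) is installed:

# Sketch; getattribute_from_module is the helper defined above.
import importlib

mbart = importlib.import_module(".mbart", "transformers.models")
tok_cls = getattribute_from_module(mbart, "MBartTokenizer")    # direct hit on the module
pair = getattribute_from_module(mbart, ("MBartModel", None))   # tuples resolve element-wise; None stays None
# On a miss, the helper retries against the top-level "transformers" package,
# which covers mapping entries whose value lives under another model type.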
+class _LazyAutoMapping(OrderedDict):
+    """
+    A mapping from config classes to objects (model or tokenizer classes, for instance) that loads its keys and
+    values lazily, the first time they are accessed.
+
+    Args:
+        - config_mapping: The mapping from model types to config class names
+        - model_mapping: The mapping from model types to model (or tokenizer) class names
+    """
+
+    def __init__(self, config_mapping, model_mapping):
+        self._config_mapping = config_mapping
+        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
+        self._model_mapping = model_mapping
+        self._modules = {}
+
+    def __getitem__(self, key):
+        model_type = self._reverse_config_mapping[key.__name__]
+        if model_type not in self._model_mapping:
+            raise KeyError(key)
+        model_name = self._model_mapping[model_type]
+        return self._load_attr_from_module(model_type, model_name)
+
+    def _load_attr_from_module(self, model_type, attr):
+        module_name = model_type_to_module_name(model_type)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+        return getattribute_from_module(self._modules[module_name], attr)
+
+    def keys(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._config_mapping.items()
+            if key in self._model_mapping.keys()
+        ]
+
+    def values(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._model_mapping.items()
+            if key in self._config_mapping.keys()
+        ]
+
+    def items(self):
+        return [
+            (
+                self._load_attr_from_module(key, self._config_mapping[key]),
+                self._load_attr_from_module(key, self._model_mapping[key]),
+            )
+            for key in self._model_mapping.keys()
+            if key in self._config_mapping.keys()
+        ]
+
+    def __iter__(self):
+        return iter(self.keys())
+
+    def __contains__(self, item):
+        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
+            return False
+        model_type = self._reverse_config_mapping[item.__name__]
+        return model_type in self._model_mapping
src/transformers/models/auto/configuration_auto.py
@@ -13,215 +13,140 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Auto Config class. """
+import importlib
 import re
+import warnings
 from collections import OrderedDict
+from typing import List, Union

 from ...configuration_utils import PretrainedConfig
-from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
-from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
-from ..beit.configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
-from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
-from ..bert_generation.configuration_bert_generation import BertGenerationConfig
-from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
-from ..bigbird_pegasus.configuration_bigbird_pegasus import (
-    BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    BigBirdPegasusConfig,
-)
-from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
-from ..blenderbot_small.configuration_blenderbot_small import (
-    BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    BlenderbotSmallConfig,
-)
-from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
-from ..canine.configuration_canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig
-from ..clip.configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig
-from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
-from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
-from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
-from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
-from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
-from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
-from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
-from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
-from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
-from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
-from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
-from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
-from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
-from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
-from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
-from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
-from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
-from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
-from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
-from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
-from ..luke.configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
-from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
-from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
-from ..marian.configuration_marian import MarianConfig
-from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
-from ..megatron_bert.configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
-from ..mobilebert.configuration_mobilebert import MobileBertConfig
-from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
-from ..mt5.configuration_mt5 import MT5Config
-from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
-from ..pegasus.configuration_pegasus import PegasusConfig
-from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
-from ..rag.configuration_rag import RagConfig
-from ..reformer.configuration_reformer import ReformerConfig
-from ..rembert.configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
-from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
-from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
-from ..roformer.configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
-from ..speech_to_text.configuration_speech_to_text import (
-    SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    Speech2TextConfig,
-)
-from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
-from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
-from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
-from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
-from ..visual_bert.configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
-from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
-from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
-from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
-from ..xlm_prophetnet.configuration_xlm_prophetnet import (
-    XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    XLMProphetNetConfig,
-)
-from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
-from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
-
-
-ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
-    (key, value)
-    for pretrained_map in [
-        # Add archive maps here
-        BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DETR_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-    ]
-    for key, value, in pretrained_map.items()
-)
-
-
-CONFIG_MAPPING = OrderedDict(
+
+CONFIG_MAPPING_NAMES = OrderedDict(
     [
         # Add configs here
-        ("beit", BeitConfig),
-        ("rembert", RemBertConfig),
-        ("visual_bert", VisualBertConfig),
-        ("canine", CanineConfig),
-        ("roformer", RoFormerConfig),
-        ("clip", CLIPConfig),
-        ("bigbird_pegasus", BigBirdPegasusConfig),
-        ("deit", DeiTConfig),
-        ("luke", LukeConfig),
-        ("detr", DetrConfig),
-        ("gpt_neo", GPTNeoConfig),
-        ("big_bird", BigBirdConfig),
-        ("speech_to_text", Speech2TextConfig),
-        ("vit", ViTConfig),
-        ("wav2vec2", Wav2Vec2Config),
-        ("m2m_100", M2M100Config),
-        ("convbert", ConvBertConfig),
-        ("led", LEDConfig),
-        ("blenderbot-small", BlenderbotSmallConfig),
-        ("retribert", RetriBertConfig),
-        ("ibert", IBertConfig),
-        ("mt5", MT5Config),
-        ("t5", T5Config),
-        ("mobilebert", MobileBertConfig),
-        ("distilbert", DistilBertConfig),
-        ("albert", AlbertConfig),
-        ("bert-generation", BertGenerationConfig),
-        ("camembert", CamembertConfig),
-        ("xlm-roberta", XLMRobertaConfig),
-        ("pegasus", PegasusConfig),
-        ("marian", MarianConfig),
-        ("mbart", MBartConfig),
-        ("megatron-bert", MegatronBertConfig),
-        ("mpnet", MPNetConfig),
-        ("bart", BartConfig),
-        ("blenderbot", BlenderbotConfig),
-        ("reformer", ReformerConfig),
-        ("longformer", LongformerConfig),
-        ("roberta", RobertaConfig),
-        ("deberta-v2", DebertaV2Config),
-        ("deberta", DebertaConfig),
-        ("flaubert", FlaubertConfig),
-        ("fsmt", FSMTConfig),
-        ("squeezebert", SqueezeBertConfig),
-        ("hubert", HubertConfig),
-        ("bert", BertConfig),
-        ("openai-gpt", OpenAIGPTConfig),
-        ("gpt2", GPT2Config),
-        ("transfo-xl", TransfoXLConfig),
-        ("xlnet", XLNetConfig),
-        ("xlm-prophetnet", XLMProphetNetConfig),
-        ("prophetnet", ProphetNetConfig),
-        ("xlm", XLMConfig),
-        ("ctrl", CTRLConfig),
-        ("electra", ElectraConfig),
-        ("encoder-decoder", EncoderDecoderConfig),
-        ("funnel", FunnelConfig),
-        ("lxmert", LxmertConfig),
-        ("dpr", DPRConfig),
-        ("layoutlm", LayoutLMConfig),
-        ("rag", RagConfig),
-        ("tapas", TapasConfig),
+        ("beit", "BeitConfig"),
+        ("rembert", "RemBertConfig"),
+        ("visual_bert", "VisualBertConfig"),
+        ("canine", "CanineConfig"),
+        ("roformer", "RoFormerConfig"),
+        ("clip", "CLIPConfig"),
+        ("bigbird_pegasus", "BigBirdPegasusConfig"),
+        ("deit", "DeiTConfig"),
+        ("luke", "LukeConfig"),
+        ("detr", "DetrConfig"),
+        ("gpt_neo", "GPTNeoConfig"),
+        ("big_bird", "BigBirdConfig"),
+        ("speech_to_text", "Speech2TextConfig"),
+        ("vit", "ViTConfig"),
+        ("wav2vec2", "Wav2Vec2Config"),
+        ("m2m_100", "M2M100Config"),
+        ("convbert", "ConvBertConfig"),
+        ("led", "LEDConfig"),
+        ("blenderbot-small", "BlenderbotSmallConfig"),
+        ("retribert", "RetriBertConfig"),
+        ("ibert", "IBertConfig"),
+        ("mt5", "MT5Config"),
+        ("t5", "T5Config"),
+        ("mobilebert", "MobileBertConfig"),
+        ("distilbert", "DistilBertConfig"),
+        ("albert", "AlbertConfig"),
+        ("bert-generation", "BertGenerationConfig"),
+        ("camembert", "CamembertConfig"),
+        ("xlm-roberta", "XLMRobertaConfig"),
+        ("pegasus", "PegasusConfig"),
+        ("marian", "MarianConfig"),
+        ("mbart", "MBartConfig"),
+        ("megatron-bert", "MegatronBertConfig"),
+        ("mpnet", "MPNetConfig"),
+        ("bart", "BartConfig"),
+        ("blenderbot", "BlenderbotConfig"),
+        ("reformer", "ReformerConfig"),
+        ("longformer", "LongformerConfig"),
+        ("roberta", "RobertaConfig"),
+        ("deberta-v2", "DebertaV2Config"),
+        ("deberta", "DebertaConfig"),
+        ("flaubert", "FlaubertConfig"),
+        ("fsmt", "FSMTConfig"),
+        ("squeezebert", "SqueezeBertConfig"),
+        ("hubert", "HubertConfig"),
+        ("bert", "BertConfig"),
+        ("openai-gpt", "OpenAIGPTConfig"),
+        ("gpt2", "GPT2Config"),
+        ("transfo-xl", "TransfoXLConfig"),
+        ("xlnet", "XLNetConfig"),
+        ("xlm-prophetnet", "XLMProphetNetConfig"),
+        ("prophetnet", "ProphetNetConfig"),
+        ("xlm", "XLMConfig"),
+        ("ctrl", "CTRLConfig"),
+        ("electra", "ElectraConfig"),
+        ("encoder-decoder", "EncoderDecoderConfig"),
+        ("funnel", "FunnelConfig"),
+        ("lxmert", "LxmertConfig"),
+        ("dpr", "DPRConfig"),
+        ("layoutlm", "LayoutLMConfig"),
+        ("rag", "RagConfig"),
+        ("tapas", "TapasConfig"),
     ]
 )

+CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
+    [
+        # Add archive maps here
+        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+    ]
+)
@@ -290,14 +215,136 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("mpnet", "MPNet"),
         ("tapas", "TAPAS"),
         ("hubert", "Hubert"),
+        ("barthez", "BARThez"),
+        ("phobert", "PhoBERT"),
+        ("cpm", "CPM"),
+        ("bertweet", "Bertweet"),
+        ("bert-japanese", "BertJapanese"),
+        ("byt5", "ByT5"),
+        ("mbart50", "mBART-50"),
     ]
 )

+SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict([("openai-gpt", "openai")])
+

-def _get_class_name(model_class):
+def model_type_to_module_name(key):
+    """Converts a config key to the corresponding module name."""
+    # Special treatment
+    if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
+        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
+
+    return key.replace("-", "_")
+
+
+def config_class_to_model_type(config):
+    """Converts a config class name to the corresponding model type."""
+    for key, cls in CONFIG_MAPPING_NAMES.items():
+        if cls == config:
+            return key
+    return None
+
+
+class _LazyConfigMapping(OrderedDict):
+    """
+    A dictionary that lazily loads its values when they are requested.
+    """
+
+    def __init__(self, mapping):
+        self._mapping = mapping
+        self._modules = {}
+
+    def __getitem__(self, key):
+        if key not in self._mapping:
+            raise KeyError(key)
+        value = self._mapping[key]
+        module_name = model_type_to_module_name(key)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+        return getattr(self._modules[module_name], value)
+
+    def keys(self):
+        return self._mapping.keys()
+
+    def values(self):
+        return [self[k] for k in self._mapping.keys()]
+
+    def items(self):
+        return [(k, self[k]) for k in self._mapping.keys()]
+
+    def __iter__(self):
+        return iter(self._mapping.keys())
+
+    def __contains__(self, item):
+        return item in self._mapping
+
+
+CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
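A sketch of the helpers above (illustrative, not from the commit): model_type_to_module_name covers the one special case, and CONFIG_MAPPING only resolves its string values to classes on access.

assert model_type_to_module_name("openai-gpt") == "openai"                # special-cased
assert model_type_to_module_name("blenderbot-small") == "blenderbot_small"
assert config_class_to_model_type("BertConfig") == "bert"

name = CONFIG_MAPPING_NAMES["bert"]  # just the string "BertConfig", no import
cls = CONFIG_MAPPING["bert"]         # imports transformers.models.bert lazily
assert cls.__name__ == name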
+class _LazyLoadAllMappings(OrderedDict):
+    """
+    A mapping that loads all its key/value pairs at the first access (whether by indexing, requesting keys, values,
+    etc.).
+
+    Args:
+        mapping: The mapping to load.
+    """
+
+    def __init__(self, mapping):
+        self._mapping = mapping
+        self._initialized = False
+        self._data = {}
+
+    def _initialize(self):
+        if self._initialized:
+            return
+        warnings.warn(
+            "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
+            "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
+            FutureWarning,
+        )
+
+        for model_type, map_name in self._mapping.items():
+            module_name = model_type_to_module_name(model_type)
+            module = importlib.import_module(f".{module_name}", "transformers.models")
+            mapping = getattr(module, map_name)
+            self._data.update(mapping)
+
+        self._initialized = True
+
+    def __getitem__(self, key):
+        self._initialize()
+        return self._data[key]
+
+    def keys(self):
+        self._initialize()
+        return self._data.keys()
+
+    def values(self):
+        self._initialize()
+        return self._data.values()
+
+    def items(self):
+        self._initialize()
+        return self._data.items()
+
+    def __iter__(self):
+        self._initialize()
+        return iter(self._data)
+
+    def __contains__(self, item):
+        self._initialize()
+        return item in self._data
+
+
+ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
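Sketch of the deprecation behavior (illustrative): the first touch of ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, even a membership test, runs _initialize(), which imports every config module listed in CONFIG_ARCHIVE_MAP_MAPPING_NAMES and emits the FutureWarning declared above.

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = "bert-base-uncased" in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP  # triggers _initialize()
assert any(issubclass(w.category, FutureWarning) for w in caught)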
+def _get_class_name(model_class: Union[str, List[str]]):
     if isinstance(model_class, (list, tuple)):
-        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
-    return f":class:`~transformers.{model_class.__name__}`"
+        return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None])
+    return f":class:`~transformers.{model_class}`"


 def _list_model_options(indent, config_to_class=None, use_model_types=True):
@@ -306,23 +353,26 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
     if use_model_types:
         if config_to_class is None:
             model_type_to_name = {
-                model_type: f":class:`~transformers.{config.__name__}`"
-                for model_type, config in CONFIG_MAPPING.items()
+                model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
             }
         else:
             model_type_to_name = {
-                model_type: _get_class_name(config_to_class[config])
-                for model_type, config in CONFIG_MAPPING.items()
-                if config in config_to_class
+                model_type: _get_class_name(model_class)
+                for model_type, model_class in config_to_class.items()
+                if model_type in MODEL_NAMES_MAPPING
             }
         lines = [
             f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
             for model_type in sorted(model_type_to_name.keys())
         ]
     else:
-        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
+        config_to_name = {
+            CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
+            for config, clas in config_to_class.items()
+            if config in CONFIG_MAPPING_NAMES
+        }
         config_to_model_name = {
-            config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
+            config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
         }
         lines = [
             f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
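For reference, a sketch of the docstring lines _list_model_options now renders from the string-based mappings (format taken from the f-strings above; the two entries are illustrative):

- **bert** -- :class:`~transformers.BertConfig` (BERT model)
- **gpt2** -- :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)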
src/transformers/models/auto/feature_extraction_auto.py
@@ -13,36 +13,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ AutoFeatureExtractor class. """
+import importlib
 import os
 from collections import OrderedDict

-from transformers import BeitFeatureExtractor, DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
-
-from ... import BeitConfig, DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
-from ...feature_extraction_utils import FeatureExtractionMixin
-
-# Build the list of all feature extractors
+from ...configuration_utils import PretrainedConfig
+from ...feature_extraction_utils import FeatureExtractionMixin
 from ...file_utils import FEATURE_EXTRACTOR_NAME
-from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
-from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
-
-
-FEATURE_EXTRACTOR_MAPPING = OrderedDict(
-    [
-        (BeitConfig, BeitFeatureExtractor),
-        (DeiTConfig, DeiTFeatureExtractor),
-        (Speech2TextConfig, Speech2TextFeatureExtractor),
-        (ViTConfig, ViTFeatureExtractor),
-        (Wav2Vec2Config, Wav2Vec2FeatureExtractor),
-    ]
-)
+from .auto_factory import _LazyAutoMapping
+from .configuration_auto import (
+    CONFIG_MAPPING_NAMES,
+    AutoConfig,
+    config_class_to_model_type,
+    replace_list_option_in_docstrings,
+)
+
+
+FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
+    [
+        ("beit", "BeitFeatureExtractor"),
+        ("deit", "DeiTFeatureExtractor"),
+        ("speech_to_text", "Speech2TextFeatureExtractor"),
+        ("vit", "ViTFeatureExtractor"),
+        ("wav2vec2", "Wav2Vec2FeatureExtractor"),
+    ]
+)
+
+FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)


 def feature_extractor_class_from_name(class_name: str):
-    for c in FEATURE_EXTRACTOR_MAPPING.values():
-        if c is not None and c.__name__ == class_name:
-            return c
+    for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
+        if class_name in extractors:
+            break
+
+    module = importlib.import_module(f".{module_name}", "transformers.models")
+    return getattr(module, class_name)


 class AutoFeatureExtractor:
@@ -60,7 +67,7 @@ class AutoFeatureExtractor:
         )

     @classmethod
-    @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
+    @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
         Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
@@ -142,7 +149,8 @@ class AutoFeatureExtractor:
         kwargs["_from_auto"] = True
         config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)

-        if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
+        model_type = config_class_to_model_type(type(config).__name__)
+        if model_type is not None:
             return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
         elif "feature_extractor_type" in config_dict:
             feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
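End-to-end sketch of the rewritten lookup (the checkpoint name is illustrative): from_pretrained maps type(config).__name__ back to a model type with config_class_to_model_type, then indexes FEATURE_EXTRACTOR_MAPPING, which imports e.g. transformers.models.vit only at that point.

from transformers import AutoFeatureExtractor

extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
print(type(extractor).__name__)  # ViTFeatureExtractor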
(File diff suppressed because it is too large.)
src/transformers/models/auto/modeling_flax_auto.py
@@ -18,207 +18,163 @@
 from collections import OrderedDict

 from ...utils import logging
-from ..bart.modeling_flax_bart import (
-    FlaxBartForConditionalGeneration,
-    FlaxBartForQuestionAnswering,
-    FlaxBartForSequenceClassification,
-    FlaxBartModel,
-)
-from ..bert.modeling_flax_bert import (
-    FlaxBertForMaskedLM,
-    FlaxBertForMultipleChoice,
-    FlaxBertForNextSentencePrediction,
-    FlaxBertForPreTraining,
-    FlaxBertForQuestionAnswering,
-    FlaxBertForSequenceClassification,
-    FlaxBertForTokenClassification,
-    FlaxBertModel,
-)
-from ..big_bird.modeling_flax_big_bird import (
-    FlaxBigBirdForMaskedLM,
-    FlaxBigBirdForMultipleChoice,
-    FlaxBigBirdForPreTraining,
-    FlaxBigBirdForQuestionAnswering,
-    FlaxBigBirdForSequenceClassification,
-    FlaxBigBirdForTokenClassification,
-    FlaxBigBirdModel,
-)
-from ..clip.modeling_flax_clip import FlaxCLIPModel
-from ..electra.modeling_flax_electra import (
-    FlaxElectraForMaskedLM,
-    FlaxElectraForMultipleChoice,
-    FlaxElectraForPreTraining,
-    FlaxElectraForQuestionAnswering,
-    FlaxElectraForSequenceClassification,
-    FlaxElectraForTokenClassification,
-    FlaxElectraModel,
-)
-from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
-from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
-from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel
-from ..mbart.modeling_flax_mbart import (
-    FlaxMBartForConditionalGeneration,
-    FlaxMBartForQuestionAnswering,
-    FlaxMBartForSequenceClassification,
-    FlaxMBartModel,
-)
-from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
-from ..roberta.modeling_flax_roberta import (
-    FlaxRobertaForMaskedLM,
-    FlaxRobertaForMultipleChoice,
-    FlaxRobertaForQuestionAnswering,
-    FlaxRobertaForSequenceClassification,
-    FlaxRobertaForTokenClassification,
-    FlaxRobertaModel,
-)
-from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
-from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
-from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
-from .auto_factory import _BaseAutoModelClass, auto_class_update
-from .configuration_auto import (
-    BartConfig,
-    BertConfig,
-    BigBirdConfig,
-    CLIPConfig,
-    ElectraConfig,
-    GPT2Config,
-    GPTNeoConfig,
-    MarianConfig,
-    MBartConfig,
-    MT5Config,
-    RobertaConfig,
-    T5Config,
-    ViTConfig,
-    Wav2Vec2Config,
-)
+from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
+from .configuration_auto import CONFIG_MAPPING_NAMES


 logger = logging.get_logger(__name__)


-FLAX_MODEL_MAPPING = OrderedDict(
+FLAX_MODEL_MAPPING_NAMES = OrderedDict(
     [
         # Base model mapping
-        (RobertaConfig, FlaxRobertaModel),
-        (BertConfig, FlaxBertModel),
-        (BigBirdConfig, FlaxBigBirdModel),
-        (BartConfig, FlaxBartModel),
-        (GPT2Config, FlaxGPT2Model),
-        (GPTNeoConfig, FlaxGPTNeoModel),
-        (ElectraConfig, FlaxElectraModel),
-        (CLIPConfig, FlaxCLIPModel),
-        (ViTConfig, FlaxViTModel),
-        (MBartConfig, FlaxMBartModel),
-        (T5Config, FlaxT5Model),
-        (MT5Config, FlaxMT5Model),
-        (Wav2Vec2Config, FlaxWav2Vec2Model),
-        (MarianConfig, FlaxMarianModel),
+        ("roberta", "FlaxRobertaModel"),
+        ("bert", "FlaxBertModel"),
+        ("big_bird", "FlaxBigBirdModel"),
+        ("bart", "FlaxBartModel"),
+        ("gpt2", "FlaxGPT2Model"),
+        ("gpt_neo", "FlaxGPTNeoModel"),
+        ("electra", "FlaxElectraModel"),
+        ("clip", "FlaxCLIPModel"),
+        ("vit", "FlaxViTModel"),
+        ("mbart", "FlaxMBartModel"),
+        ("t5", "FlaxT5Model"),
+        ("mt5", "FlaxMT5Model"),
+        ("wav2vec2", "FlaxWav2Vec2Model"),
+        ("marian", "FlaxMarianModel"),
     ]
 )

-FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
     [
         # Model for pre-training mapping
-        (RobertaConfig, FlaxRobertaForMaskedLM),
-        (BertConfig, FlaxBertForPreTraining),
-        (BigBirdConfig, FlaxBigBirdForPreTraining),
-        (BartConfig, FlaxBartForConditionalGeneration),
-        (ElectraConfig, FlaxElectraForPreTraining),
-        (MBartConfig, FlaxMBartForConditionalGeneration),
-        (T5Config, FlaxT5ForConditionalGeneration),
-        (MT5Config, FlaxMT5ForConditionalGeneration),
-        (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
+        ("roberta", "FlaxRobertaForMaskedLM"),
+        ("bert", "FlaxBertForPreTraining"),
+        ("big_bird", "FlaxBigBirdForPreTraining"),
+        ("bart", "FlaxBartForConditionalGeneration"),
+        ("electra", "FlaxElectraForPreTraining"),
+        ("mbart", "FlaxMBartForConditionalGeneration"),
+        ("t5", "FlaxT5ForConditionalGeneration"),
+        ("mt5", "FlaxMT5ForConditionalGeneration"),
+        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
     ]
 )

-FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Masked LM mapping
-        (RobertaConfig, FlaxRobertaForMaskedLM),
-        (BertConfig, FlaxBertForMaskedLM),
-        (BigBirdConfig, FlaxBigBirdForMaskedLM),
-        (BartConfig, FlaxBartForConditionalGeneration),
-        (ElectraConfig, FlaxElectraForMaskedLM),
-        (MBartConfig, FlaxMBartForConditionalGeneration),
+        ("roberta", "FlaxRobertaForMaskedLM"),
+        ("bert", "FlaxBertForMaskedLM"),
+        ("big_bird", "FlaxBigBirdForMaskedLM"),
+        ("bart", "FlaxBartForConditionalGeneration"),
+        ("electra", "FlaxElectraForMaskedLM"),
+        ("mbart", "FlaxMBartForConditionalGeneration"),
     ]
 )

-FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Seq2Seq Causal LM mapping
-        (BartConfig, FlaxBartForConditionalGeneration),
-        (T5Config, FlaxT5ForConditionalGeneration),
-        (MT5Config, FlaxMT5ForConditionalGeneration),
-        (MarianConfig, FlaxMarianMTModel),
+        ("bart", "FlaxBartForConditionalGeneration"),
+        ("t5", "FlaxT5ForConditionalGeneration"),
+        ("mt5", "FlaxMT5ForConditionalGeneration"),
+        ("marian", "FlaxMarianMTModel"),
     ]
 )

-FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Image classification
-        (ViTConfig, FlaxViTForImageClassification),
+        ("vit", "FlaxViTForImageClassification"),
     ]
 )

-FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
     [
         # Model for Causal LM mapping
-        (GPT2Config, FlaxGPT2LMHeadModel),
-        (GPTNeoConfig, FlaxGPTNeoForCausalLM),
+        ("gpt2", "FlaxGPT2LMHeadModel"),
+        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
    ]
 )

-FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Sequence Classification mapping
-        (RobertaConfig, FlaxRobertaForSequenceClassification),
-        (BertConfig, FlaxBertForSequenceClassification),
-        (BigBirdConfig, FlaxBigBirdForSequenceClassification),
-        (BartConfig, FlaxBartForSequenceClassification),
-        (ElectraConfig, FlaxElectraForSequenceClassification),
-        (MBartConfig, FlaxMBartForSequenceClassification),
+        ("roberta", "FlaxRobertaForSequenceClassification"),
+        ("bert", "FlaxBertForSequenceClassification"),
+        ("big_bird", "FlaxBigBirdForSequenceClassification"),
+        ("bart", "FlaxBartForSequenceClassification"),
+        ("electra", "FlaxElectraForSequenceClassification"),
+        ("mbart", "FlaxMBartForSequenceClassification"),
     ]
 )

-FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
     [
         # Model for Question Answering mapping
-        (RobertaConfig, FlaxRobertaForQuestionAnswering),
-        (BertConfig, FlaxBertForQuestionAnswering),
-        (BigBirdConfig, FlaxBigBirdForQuestionAnswering),
-        (BartConfig, FlaxBartForQuestionAnswering),
-        (ElectraConfig, FlaxElectraForQuestionAnswering),
-        (MBartConfig, FlaxMBartForQuestionAnswering),
+        ("roberta", "FlaxRobertaForQuestionAnswering"),
+        ("bert", "FlaxBertForQuestionAnswering"),
+        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
+        ("bart", "FlaxBartForQuestionAnswering"),
+        ("electra", "FlaxElectraForQuestionAnswering"),
+        ("mbart", "FlaxMBartForQuestionAnswering"),
     ]
 )

-FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Token Classification mapping
-        (RobertaConfig, FlaxRobertaForTokenClassification),
-        (BertConfig, FlaxBertForTokenClassification),
-        (BigBirdConfig, FlaxBigBirdForTokenClassification),
-        (ElectraConfig, FlaxElectraForTokenClassification),
+        ("roberta", "FlaxRobertaForTokenClassification"),
+        ("bert", "FlaxBertForTokenClassification"),
+        ("big_bird", "FlaxBigBirdForTokenClassification"),
+        ("electra", "FlaxElectraForTokenClassification"),
     ]
 )

-FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
     [
         # Model for Multiple Choice mapping
-        (RobertaConfig, FlaxRobertaForMultipleChoice),
-        (BertConfig, FlaxBertForMultipleChoice),
-        (BigBirdConfig, FlaxBigBirdForMultipleChoice),
-        (ElectraConfig, FlaxElectraForMultipleChoice),
+        ("roberta", "FlaxRobertaForMultipleChoice"),
+        ("bert", "FlaxBertForMultipleChoice"),
+        ("big_bird", "FlaxBigBirdForMultipleChoice"),
+        ("electra", "FlaxElectraForMultipleChoice"),
     ]
 )

-FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
     [
-        (BertConfig, FlaxBertForNextSentencePrediction),
+        ("bert", "FlaxBertForNextSentencePrediction"),
     ]
 )

+
+FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
+FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
+FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
+)
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
+)


 class FlaxAutoModel(_BaseAutoModelClass):
     _model_mapping = FLAX_MODEL_MAPPING
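Note how cheap the module now is to import (illustrative sketch, not from the commit): every FLAX_* mapping is a _LazyAutoMapping over string names, so importing modeling_flax_auto pulls in no Flax model code; that import only happens on the first lookup.

from transformers import BertConfig
from transformers.models.auto.modeling_flax_auto import FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING_NAMES

print(FLAX_MODEL_MAPPING_NAMES["bert"])    # "FlaxBertModel", no import triggered
flax_cls = FLAX_MODEL_MAPPING[BertConfig]  # imports modeling_flax_bert here (needs flax/jax installed)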
@ -19,492 +19,298 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
from ...utils import logging
|
||||
|
||||
# Add modeling imports here
|
||||
from ..albert.modeling_tf_albert import (
|
||||
TFAlbertForMaskedLM,
|
||||
TFAlbertForMultipleChoice,
|
||||
TFAlbertForPreTraining,
|
||||
TFAlbertForQuestionAnswering,
|
||||
TFAlbertForSequenceClassification,
|
||||
TFAlbertForTokenClassification,
|
||||
TFAlbertModel,
|
||||
)
|
||||
from ..bart.modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
|
||||
from ..bert.modeling_tf_bert import (
|
||||
TFBertForMaskedLM,
|
||||
TFBertForMultipleChoice,
|
||||
TFBertForNextSentencePrediction,
|
||||
TFBertForPreTraining,
|
||||
TFBertForQuestionAnswering,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertForTokenClassification,
|
||||
TFBertLMHeadModel,
|
||||
TFBertModel,
|
||||
)
|
||||
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||
from ..blenderbot_small.modeling_tf_blenderbot_small import (
|
||||
TFBlenderbotSmallForConditionalGeneration,
|
||||
TFBlenderbotSmallModel,
|
||||
)
|
||||
from ..camembert.modeling_tf_camembert import (
|
||||
TFCamembertForMaskedLM,
|
||||
TFCamembertForMultipleChoice,
|
||||
TFCamembertForQuestionAnswering,
|
||||
TFCamembertForSequenceClassification,
|
||||
TFCamembertForTokenClassification,
|
||||
TFCamembertModel,
|
||||
)
|
||||
from ..convbert.modeling_tf_convbert import (
|
||||
TFConvBertForMaskedLM,
|
||||
TFConvBertForMultipleChoice,
|
||||
TFConvBertForQuestionAnswering,
|
||||
TFConvBertForSequenceClassification,
|
||||
TFConvBertForTokenClassification,
|
||||
TFConvBertModel,
|
||||
)
|
||||
from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel
|
||||
from ..distilbert.modeling_tf_distilbert import (
|
||||
TFDistilBertForMaskedLM,
|
||||
TFDistilBertForMultipleChoice,
|
||||
TFDistilBertForQuestionAnswering,
|
||||
TFDistilBertForSequenceClassification,
|
||||
TFDistilBertForTokenClassification,
|
||||
TFDistilBertModel,
|
||||
)
|
||||
from ..dpr.modeling_tf_dpr import TFDPRQuestionEncoder
|
||||
from ..electra.modeling_tf_electra import (
|
||||
TFElectraForMaskedLM,
|
||||
TFElectraForMultipleChoice,
|
||||
TFElectraForPreTraining,
|
||||
TFElectraForQuestionAnswering,
|
||||
TFElectraForSequenceClassification,
|
||||
TFElectraForTokenClassification,
|
||||
TFElectraModel,
|
||||
)
|
||||
from ..flaubert.modeling_tf_flaubert import (
|
||||
TFFlaubertForMultipleChoice,
|
||||
TFFlaubertForQuestionAnsweringSimple,
|
||||
TFFlaubertForSequenceClassification,
|
||||
TFFlaubertForTokenClassification,
|
||||
TFFlaubertModel,
|
||||
TFFlaubertWithLMHeadModel,
|
||||
)
|
||||
from ..funnel.modeling_tf_funnel import (
|
||||
TFFunnelBaseModel,
|
||||
TFFunnelForMaskedLM,
|
||||
TFFunnelForMultipleChoice,
|
||||
TFFunnelForPreTraining,
|
||||
TFFunnelForQuestionAnswering,
|
||||
TFFunnelForSequenceClassification,
|
||||
TFFunnelForTokenClassification,
|
||||
TFFunnelModel,
|
||||
)
|
||||
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
|
||||
from ..hubert.modeling_tf_hubert import TFHubertModel
|
||||
from ..layoutlm.modeling_tf_layoutlm import (
|
||||
TFLayoutLMForMaskedLM,
|
||||
TFLayoutLMForSequenceClassification,
|
||||
TFLayoutLMForTokenClassification,
|
||||
TFLayoutLMModel,
|
||||
)
|
||||
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
|
||||
from ..longformer.modeling_tf_longformer import (
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
TFMobileBertForNextSentencePrediction,
TFMobileBertForPreTraining,
TFMobileBertForQuestionAnswering,
TFMobileBertForSequenceClassification,
TFMobileBertForTokenClassification,
TFMobileBertModel,
)
from ..mpnet.modeling_tf_mpnet import (
TFMPNetForMaskedLM,
TFMPNetForMultipleChoice,
TFMPNetForQuestionAnswering,
TFMPNetForSequenceClassification,
TFMPNetForTokenClassification,
TFMPNetModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..rembert.modeling_tf_rembert import (
TFRemBertForCausalLM,
TFRemBertForMaskedLM,
TFRemBertForMultipleChoice,
TFRemBertForQuestionAnswering,
TFRemBertForSequenceClassification,
TFRemBertForTokenClassification,
TFRemBertModel,
)
from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaModel,
)
from ..roformer.modeling_tf_roformer import (
TFRoFormerForCausalLM,
TFRoFormerForMaskedLM,
TFRoFormerForMultipleChoice,
TFRoFormerForQuestionAnswering,
TFRoFormerForSequenceClassification,
TFRoFormerForTokenClassification,
TFRoFormerModel,
)
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLModel,
)
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMModel,
TFXLMWithLMHeadModel,
)
from ..xlm_roberta.modeling_tf_xlm_roberta import (
TFXLMRobertaForMaskedLM,
TFXLMRobertaForMultipleChoice,
TFXLMRobertaForQuestionAnswering,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TFXLMRobertaModel,
)
from ..xlnet.modeling_tf_xlnet import (
TFXLNetForMultipleChoice,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetLMHeadModel,
TFXLNetModel,
)
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
AlbertConfig,
BartConfig,
BertConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
ConvBertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
FunnelConfig,
GPT2Config,
HubertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
RemBertConfig,
RobertaConfig,
RoFormerConfig,
T5Config,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)


TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
(RemBertConfig, TFRemBertModel),
(RoFormerConfig, TFRoFormerModel),
(ConvBertConfig, TFConvBertModel),
(LEDConfig, TFLEDModel),
(LxmertConfig, TFLxmertModel),
(MT5Config, TFMT5Model),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
(BartConfig, TFBartModel),
(CamembertConfig, TFCamembertModel),
(XLMRobertaConfig, TFXLMRobertaModel),
(LongformerConfig, TFLongformerModel),
(RobertaConfig, TFRobertaModel),
(LayoutLMConfig, TFLayoutLMModel),
(BertConfig, TFBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(GPT2Config, TFGPT2Model),
(MobileBertConfig, TFMobileBertModel),
(TransfoXLConfig, TFTransfoXLModel),
(XLNetConfig, TFXLNetModel),
(FlaubertConfig, TFFlaubertModel),
(XLMConfig, TFXLMModel),
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
(MBartConfig, TFMBartModel),
(MarianConfig, TFMarianModel),
(PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
(Wav2Vec2Config, TFWav2Vec2Model),
(HubertConfig, TFHubertModel),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
("t5", "TFT5Model"),
("distilbert", "TFDistilBertModel"),
("albert", "TFAlbertModel"),
("bart", "TFBartModel"),
("camembert", "TFCamembertModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("longformer", "TFLongformerModel"),
("roberta", "TFRobertaModel"),
("layoutlm", "TFLayoutLMModel"),
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
("flaubert", "TFFlaubertModel"),
("xlm", "TFXLMModel"),
("ctrl", "TFCTRLModel"),
("electra", "TFElectraModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("dpr", "TFDPRQuestionEncoder"),
("mpnet", "TFMPNetModel"),
("mbart", "TFMBartModel"),
("marian", "TFMarianModel"),
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"),
]
)
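
Every task mapping in this file follows the pattern visible above: each `(SomeConfig, TFSomeModel)` pair keyed by an imported config class is replaced by a `("model_type", "TFSomeModel")` pair of plain strings, so defining the table no longer imports any modeling file. A minimal sketch of the difference (standard library only; the names are just examples):

from collections import OrderedDict

# Old style: the dict was keyed by classes, so TFBertModel (and TensorFlow with
# it) had to be imported before the mapping could even be defined.
# New style: only strings are stored; nothing heavy is imported at definition time.
lazy_names = OrderedDict([("bert", "TFBertModel")])
assert lazy_names["bert"] == "TFBertModel"  # still just a name, resolved on demand
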
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
(LxmertConfig, TFLxmertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForPreTraining),
(BartConfig, TFBartForConditionalGeneration),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForPreTraining),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(MobileBertConfig, TFMobileBertForPreTraining),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(CTRLConfig, TFCTRLLMHeadModel),
(ElectraConfig, TFElectraForPreTraining),
(FunnelConfig, TFFunnelForPreTraining),
(MPNetConfig, TFMPNetForMaskedLM),
("lxmert", "TFLxmertForPreTraining"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForPreTraining"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForPreTraining"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForPreTraining"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForPreTraining"),
("funnel", "TFFunnelForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
]
)

TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
(RemBertConfig, TFRemBertForMaskedLM),
(RoFormerConfig, TFRoFormerForMaskedLM),
(ConvBertConfig, TFConvBertForMaskedLM),
(LEDConfig, TFLEDForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(MarianConfig, TFMarianMTModel),
(BartConfig, TFBartForConditionalGeneration),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(MobileBertConfig, TFMobileBertForMaskedLM),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(CTRLConfig, TFCTRLLMHeadModel),
(ElectraConfig, TFElectraForMaskedLM),
(FunnelConfig, TFFunnelForMaskedLM),
(MPNetConfig, TFMPNetForMaskedLM),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("marian", "TFMarianMTModel"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)

TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
(RemBertConfig, TFRemBertForCausalLM),
(RoFormerConfig, TFRoFormerForCausalLM),
(BertConfig, TFBertLMHeadModel),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(
XLMConfig,
TFXLMWithLMHeadModel,
),  # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
(CTRLConfig, TFCTRLLMHeadModel),
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
]
)

TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
(RemBertConfig, TFRemBertForMaskedLM),
(RoFormerConfig, TFRoFormerForMaskedLM),
(ConvBertConfig, TFConvBertForMaskedLM),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(MobileBertConfig, TFMobileBertForMaskedLM),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(ElectraConfig, TFElectraForMaskedLM),
(FunnelConfig, TFFunnelForMaskedLM),
(MPNetConfig, TFMPNetForMaskedLM),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)


TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(LEDConfig, TFLEDForConditionalGeneration),
(MT5Config, TFMT5ForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),
(MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration),
("led", "TFLEDForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
(RemBertConfig, TFRemBertForSequenceClassification),
(RoFormerConfig, TFRoFormerForSequenceClassification),
(ConvBertConfig, TFConvBertForSequenceClassification),
(DistilBertConfig, TFDistilBertForSequenceClassification),
(AlbertConfig, TFAlbertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(LongformerConfig, TFLongformerForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(LayoutLMConfig, TFLayoutLMForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
(MobileBertConfig, TFMobileBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
(ElectraConfig, TFElectraForSequenceClassification),
(FunnelConfig, TFFunnelForSequenceClassification),
(GPT2Config, TFGPT2ForSequenceClassification),
(MPNetConfig, TFMPNetForSequenceClassification),
(OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
(TransfoXLConfig, TFTransfoXLForSequenceClassification),
(CTRLConfig, TFCTRLForSequenceClassification),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("albert", "TFAlbertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
]
)

TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
(RemBertConfig, TFRemBertForQuestionAnswering),
(RoFormerConfig, TFRoFormerForQuestionAnswering),
(ConvBertConfig, TFConvBertForQuestionAnswering),
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(AlbertConfig, TFAlbertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(LongformerConfig, TFLongformerForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
(MobileBertConfig, TFMobileBertForQuestionAnswering),
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(ElectraConfig, TFElectraForQuestionAnswering),
(FunnelConfig, TFFunnelForQuestionAnswering),
(MPNetConfig, TFMPNetForQuestionAnswering),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("albert", "TFAlbertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
]
)

TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
(RemBertConfig, TFRemBertForTokenClassification),
(RoFormerConfig, TFRoFormerForTokenClassification),
(ConvBertConfig, TFConvBertForTokenClassification),
(DistilBertConfig, TFDistilBertForTokenClassification),
(AlbertConfig, TFAlbertForTokenClassification),
(CamembertConfig, TFCamembertForTokenClassification),
(FlaubertConfig, TFFlaubertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(LongformerConfig, TFLongformerForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(LayoutLMConfig, TFLayoutLMForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),
(ElectraConfig, TFElectraForTokenClassification),
(FunnelConfig, TFFunnelForTokenClassification),
(MPNetConfig, TFMPNetForTokenClassification),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("albert", "TFAlbertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
]
)

TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
(RemBertConfig, TFRemBertForMultipleChoice),
(RoFormerConfig, TFRoFormerForMultipleChoice),
(ConvBertConfig, TFConvBertForMultipleChoice),
(CamembertConfig, TFCamembertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(LongformerConfig, TFLongformerForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
(MobileBertConfig, TFMobileBertForMultipleChoice),
(XLNetConfig, TFXLNetForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(AlbertConfig, TFAlbertForMultipleChoice),
(ElectraConfig, TFElectraForMultipleChoice),
(FunnelConfig, TFFunnelForMultipleChoice),
(MPNetConfig, TFMPNetForMultipleChoice),
("rembert", "TFRemBertForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
]
)

TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
(BertConfig, TFBertForNextSentencePrediction),
(MobileBertConfig, TFMobileBertForNextSentencePrediction),
("bert", "TFBertForNextSentencePrediction"),
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)


TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)


class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING

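`_LazyAutoMapping` is what turns these name tables back into a "config class -> model class" mapping at lookup time. The real implementation lives in `auto_factory.py`; the snippet below is only a rough sketch of the idea, with the module-name normalization reduced to a single `replace`:

import importlib
from collections import OrderedDict


class SimpleLazyMapping:
    # Rough sketch only, not the real _LazyAutoMapping: maps a config class
    # to a model class, importing the defining module on first access.
    def __init__(self, config_names, model_names):
        self._config_names = config_names  # e.g. OrderedDict([("bert", "BertConfig")])
        self._model_names = model_names    # e.g. OrderedDict([("bert", "TFBertModel")])

    def __getitem__(self, config_class):
        reverse = {name: model_type for model_type, name in self._config_names.items()}
        model_type = reverse[config_class.__name__]
        class_name = self._model_names[model_type]
        # Only now is the modeling file imported; model types like "xlm-roberta"
        # need normalizing to a module name ("xlm_roberta"), done here naively.
        module = importlib.import_module(f".{model_type.replace('-', '_')}", "transformers.models")
        return getattr(module, class_name)
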
@ -14,12 +14,12 @@
# limitations under the License.
""" Auto Tokenizer class. """

import importlib
import json
import os
from collections import OrderedDict
from typing import Dict, Optional, Union

from ... import GPTNeoConfig
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
cached_path,
@ -29,315 +29,183 @@ from ...file_utils import (
is_tokenizers_available,
)
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..byt5.tokenization_byt5 import ByT5Tokenizer
from ..canine.tokenization_canine import CanineTokenizer
from ..convbert.tokenization_convbert import ConvBertTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
from ..distilbert.tokenization_distilbert import DistilBertTokenizer
from ..dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from ..electra.tokenization_electra import ElectraTokenizer
from ..flaubert.tokenization_flaubert import FlaubertTokenizer
from ..fsmt.tokenization_fsmt import FSMTTokenizer
from ..funnel.tokenization_funnel import FunnelTokenizer
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
from ..herbert.tokenization_herbert import HerbertTokenizer
from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer
from ..led.tokenization_led import LEDTokenizer
from ..longformer.tokenization_longformer import LongformerTokenizer
from ..luke.tokenization_luke import LukeTokenizer
from ..lxmert.tokenization_lxmert import LxmertTokenizer
from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer
from ..mpnet.tokenization_mpnet import MPNetTokenizer
from ..openai.tokenization_openai import OpenAIGPTTokenizer
from ..phobert.tokenization_phobert import PhobertTokenizer
from ..prophetnet.tokenization_prophetnet import ProphetNetTokenizer
from ..rag.tokenization_rag import RagTokenizer
from ..retribert.tokenization_retribert import RetriBertTokenizer
from ..roberta.tokenization_roberta import RobertaTokenizer
from ..roformer.tokenization_roformer import RoFormerTokenizer
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
from ..tapas.tokenization_tapas import TapasTokenizer
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
from ..xlm.tokenization_xlm import XLMTokenizer
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
AlbertConfig,
CONFIG_MAPPING_NAMES,
AutoConfig,
BartConfig,
BertConfig,
BertGenerationConfig,
BigBirdConfig,
BigBirdPegasusConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CanineConfig,
ConvBertConfig,
CTRLConfig,
DebertaConfig,
DebertaV2Config,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
FSMTConfig,
FunnelConfig,
GPT2Config,
HubertConfig,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LukeConfig,
LxmertConfig,
M2M100Config,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
RoFormerConfig,
Speech2TextConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
XLNetConfig,
config_class_to_model_type,
replace_list_option_in_docstrings,
)


if is_sentencepiece_available():
from ..albert.tokenization_albert import AlbertTokenizer
from ..barthez.tokenization_barthez import BarthezTokenizer
from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer
from ..big_bird.tokenization_big_bird import BigBirdTokenizer
from ..camembert.tokenization_camembert import CamembertTokenizer
from ..cpm.tokenization_cpm import CpmTokenizer
from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer
from ..m2m_100 import M2M100Tokenizer
from ..marian.tokenization_marian import MarianTokenizer
from ..mbart.tokenization_mbart import MBartTokenizer
from ..mbart.tokenization_mbart50 import MBart50Tokenizer
from ..mt5 import MT5Tokenizer
from ..pegasus.tokenization_pegasus import PegasusTokenizer
from ..reformer.tokenization_reformer import ReformerTokenizer
from ..speech_to_text import Speech2TextTokenizer
from ..t5.tokenization_t5 import T5Tokenizer
from ..xlm_prophetnet.tokenization_xlm_prophetnet import XLMProphetNetTokenizer
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from ..xlnet.tokenization_xlnet import XLNetTokenizer
else:
AlbertTokenizer = None
BarthezTokenizer = None
BertGenerationTokenizer = None
BigBirdTokenizer = None
CamembertTokenizer = None
CpmTokenizer = None
DebertaV2Tokenizer = None
MarianTokenizer = None
MBartTokenizer = None
MBart50Tokenizer = None
MT5Tokenizer = None
PegasusTokenizer = None
ReformerTokenizer = None
T5Tokenizer = None
XLMRobertaTokenizer = None
XLNetTokenizer = None
XLMProphetNetTokenizer = None
M2M100Tokenizer = None
Speech2TextTokenizer = None

if is_tokenizers_available():
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ..albert.tokenization_albert_fast import AlbertTokenizerFast
from ..bart.tokenization_bart_fast import BartTokenizerFast
from ..barthez.tokenization_barthez_fast import BarthezTokenizerFast
from ..bert.tokenization_bert_fast import BertTokenizerFast
from ..big_bird.tokenization_big_bird_fast import BigBirdTokenizerFast
from ..camembert.tokenization_camembert_fast import CamembertTokenizerFast
from ..convbert.tokenization_convbert_fast import ConvBertTokenizerFast
from ..cpm.tokenization_cpm_fast import CpmTokenizerFast
from ..deberta.tokenization_deberta_fast import DebertaTokenizerFast
from ..distilbert.tokenization_distilbert_fast import DistilBertTokenizerFast
from ..dpr.tokenization_dpr_fast import DPRQuestionEncoderTokenizerFast
from ..electra.tokenization_electra_fast import ElectraTokenizerFast
from ..funnel.tokenization_funnel_fast import FunnelTokenizerFast
from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast
from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast
from ..led.tokenization_led_fast import LEDTokenizerFast
from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
from ..mbart.tokenization_mbart50_fast import MBart50TokenizerFast
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
from ..mpnet.tokenization_mpnet_fast import MPNetTokenizerFast
from ..mt5 import MT5TokenizerFast
from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
from ..t5.tokenization_t5_fast import T5TokenizerFast
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from ..xlnet.tokenization_xlnet_fast import XLNetTokenizerFast

else:
AlbertTokenizerFast = None
BartTokenizerFast = None
BarthezTokenizerFast = None
BertTokenizerFast = None
BigBirdTokenizerFast = None
CamembertTokenizerFast = None
ConvBertTokenizerFast = None
CpmTokenizerFast = None
DebertaTokenizerFast = None
DistilBertTokenizerFast = None
DPRQuestionEncoderTokenizerFast = None
ElectraTokenizerFast = None
FunnelTokenizerFast = None
GPT2TokenizerFast = None
HerbertTokenizerFast = None
LayoutLMTokenizerFast = None
LEDTokenizerFast = None
LongformerTokenizerFast = None
LxmertTokenizerFast = None
MBartTokenizerFast = None
MBart50TokenizerFast = None
MobileBertTokenizerFast = None
MPNetTokenizerFast = None
MT5TokenizerFast = None
OpenAIGPTTokenizerFast = None
PegasusTokenizerFast = None
ReformerTokenizerFast = None
RetriBertTokenizerFast = None
RobertaTokenizerFast = None
RoFormerTokenizerFast = None
SqueezeBertTokenizerFast = None
T5TokenizerFast = None
XLMRobertaTokenizerFast = None
XLNetTokenizerFast = None
PreTrainedTokenizerFast = None


logger = logging.get_logger(__name__)


TOKENIZER_MAPPING = OrderedDict(
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
(RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
(T5Config, (T5Tokenizer, T5TokenizerFast)),
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
(CamembertConfig, (CamembertTokenizer, CamembertTokenizerFast)),
(PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)),
(BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotConfig, (BlenderbotTokenizer, None)),
(BartConfig, (BartTokenizer, BartTokenizerFast)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
(SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
(TransfoXLConfig, (TransfoXLTokenizer, None)),
(XLNetConfig, (XLNetTokenizer, XLNetTokenizerFast)),
(FlaubertConfig, (FlaubertTokenizer, None)),
(XLMConfig, (XLMTokenizer, None)),
(CTRLConfig, (CTRLTokenizer, None)),
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(DebertaConfig, (DebertaTokenizer, DebertaTokenizerFast)),
(DebertaV2Config, (DebertaV2Tokenizer, None)),
(RagConfig, (RagTokenizer, None)),
(XLMProphetNetConfig, (XLMProphetNetTokenizer, None)),
(Speech2TextConfig, (Speech2TextTokenizer, None)),
(M2M100Config, (M2M100Tokenizer, None)),
(ProphetNetConfig, (ProphetNetTokenizer, None)),
(MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
(ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
(BigBirdConfig, (BigBirdTokenizer, BigBirdTokenizerFast)),
(IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
(HubertConfig, (Wav2Vec2CTCTokenizer, None)),
(GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)),
(LukeConfig, (LukeTokenizer, None)),
(BigBirdPegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
(CanineConfig, (CanineTokenizer, None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"albert",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"camembert",
(
"CamembertTokenizer" if is_sentencepiece_available() else None,
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"pegasus",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
("DPRQuestionEncoderTokenizer", "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None),
),
("squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None)),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("byt5", ("ByT5Tokenizer", None)),
(
"cpm",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
(
"barthez",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart50",
(
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
]
)

# For tokenizers which are not directly mapped from a config
NO_CONFIG_TOKENIZER = [
BertJapaneseTokenizer,
BertweetTokenizer,
ByT5Tokenizer,
CpmTokenizer,
CpmTokenizerFast,
HerbertTokenizer,
HerbertTokenizerFast,
PhobertTokenizer,
BarthezTokenizer,
BarthezTokenizerFast,
MBart50Tokenizer,
MBart50TokenizerFast,
PreTrainedTokenizerFast,
]
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)


SLOW_TOKENIZER_MAPPING = {
k: (v[0] if v[0] is not None else v[1])
for k, v in TOKENIZER_MAPPING.items()
if (v[0] is not None or v[1] is not None)
}
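
As an illustration of the comprehension above (example entries only), each model type keeps its slow tokenizer name and falls back to the fast one when no slow class exists:

example = {
    "bert": ("BertTokenizer", "BertTokenizerFast"),
    "rag": ("RagTokenizer", None),
    "fast-only": (None, "SomeTokenizerFast"),  # hypothetical model type
}
slow = {k: (v[0] if v[0] is not None else v[1]) for k, v in example.items() if (v[0] is not None or v[1] is not None)}
assert slow == {"bert": "BertTokenizer", "rag": "RagTokenizer", "fast-only": "SomeTokenizerFast"}
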
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


def tokenizer_class_from_name(class_name: str):
all_tokenizer_classes = (
[v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+ [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+ [v for v in NO_CONFIG_TOKENIZER if v is not None]
)
for c in all_tokenizer_classes:
if c.__name__ == class_name:
return c
if class_name == "PreTrainedTokenizerFast":
return PreTrainedTokenizerFast

for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers:
break

module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
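
With the rewrite, resolving a class name triggers the import only at call time; a hedged usage sketch, assuming a standard transformers install:

# Nothing from transformers.models.bert has been imported yet at this point.
cls = tokenizer_class_from_name("BertTokenizerFast")
# The new code path imports transformers.models.bert.tokenization_bert_fast on
# demand and returns the BertTokenizerFast class it finds there.
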
def get_tokenizer_config(
@ -454,7 +322,7 @@ class AutoTokenizer:
)

@classmethod
@replace_list_option_in_docstrings(SLOW_TOKENIZER_MAPPING)
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r"""
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
@ -565,7 +433,8 @@ class AutoTokenizer:
)
config = config.encoder

if type(config) in TOKENIZER_MAPPING.keys():
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

@ -33,10 +33,8 @@ _import_structure = {

if is_sentencepiece_available():
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]

if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]

if is_torch_available():

@ -72,10 +70,8 @@ if TYPE_CHECKING:

if is_sentencepiece_available():
from .tokenization_mbart import MBartTokenizer
from .tokenization_mbart50 import MBart50Tokenizer

if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
from .tokenization_mbart_fast import MBartTokenizerFast

if is_torch_available():

42
src/transformers/models/mbart50/__init__.py
Normal file
@ -0,0 +1,42 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available


_import_structure = {}

if is_sentencepiece_available():
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]

if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]


if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart50 import MBart50Tokenizer

if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
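
Thanks to `_LazyModule`, the new package behaves like any other model module; a short usage sketch (the checkpoint name is only illustrative, e.g. the public facebook/mbart-large-50):

from transformers.models.mbart50 import MBart50TokenizerFast  # resolved lazily

tok = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
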
|
@ -1781,25 +1781,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if config_tokenizer_class is None:
|
||||
# Third attempt. If we have not yet found the original type of the tokenizer,
|
||||
# we are loading we see if we can infer it from the type of the configuration file
|
||||
from .models.auto.configuration_auto import CONFIG_MAPPING # tests_ignore
|
||||
from .models.auto.tokenization_auto import TOKENIZER_MAPPING # tests_ignore
|
||||
from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore
|
||||
|
||||
if hasattr(config, "model_type"):
|
||||
config_class = CONFIG_MAPPING.get(config.model_type)
|
||||
model_type = config.model_type
|
||||
else:
|
||||
# Fallback: use pattern matching on the string.
|
||||
config_class = None
|
||||
for pattern, config_class_tmp in CONFIG_MAPPING.items():
|
||||
model_type = None
|
||||
for pattern in TOKENIZER_MAPPING_NAMES.keys():
|
||||
if pattern in str(pretrained_model_name_or_path):
|
||||
config_class = config_class_tmp
|
||||
model_type = pattern
|
||||
break
|
||||
|
||||
if config_class in TOKENIZER_MAPPING.keys():
|
||||
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
|
||||
if config_tokenizer_class is not None:
|
||||
config_tokenizer_class = config_tokenizer_class.__name__
|
||||
else:
|
||||
config_tokenizer_class = config_tokenizer_class_fast.__name__
|
||||
if model_type is not None:
|
||||
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES[model_type]
|
||||
if config_tokenizer_class is None:
|
||||
config_tokenizer_class = config_tokenizer_class_fast
|
||||
|
||||
if config_tokenizer_class is not None:
|
||||
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
|
||||
|
@ -74,6 +74,7 @@ from .file_utils import (
|
||||
)
|
||||
from .modelcard import TrainingSummary
|
||||
from .modeling_utils import PreTrainedModel, unwrap_model
|
||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
||||
from .optimization import Adafactor, AdamW, get_scheduler
|
||||
from .tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from .trainer_callback import (
|
||||
@ -125,7 +126,6 @@ from .trainer_utils import (
|
||||
)
|
||||
from .training_args import ParallelMode, TrainingArguments
|
||||
from .utils import logging
|
||||
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
||||
|
||||
|
||||
_is_torch_generator_available = False
|
||||
|
@ -191,7 +191,7 @@ class LxmertTokenizerFast:
|
||||
requires_backends(cls, ["tokenizers"])
|
||||
|
||||
|
||||
class MBart50TokenizerFast:
|
||||
class MBartTokenizerFast:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
@ -200,7 +200,7 @@ class MBart50TokenizerFast:
|
||||
requires_backends(cls, ["tokenizers"])
|
||||
|
||||
|
||||
class MBartTokenizerFast:
|
||||
class MBart50TokenizerFast:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
@ -1,374 +0,0 @@
|
||||
# THIS FILE HAS BEEN AUTOGENERATED. To update:
|
||||
# 1. modify: models/auto/modeling_auto.py
|
||||
# 2. run: python utils/class_mapping_update.py
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("RemBertConfig", "RemBertForQuestionAnswering"),
|
||||
("CanineConfig", "CanineForQuestionAnswering"),
|
||||
("RoFormerConfig", "RoFormerForQuestionAnswering"),
|
||||
("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
|
||||
("BigBirdConfig", "BigBirdForQuestionAnswering"),
|
||||
("ConvBertConfig", "ConvBertForQuestionAnswering"),
|
||||
("LEDConfig", "LEDForQuestionAnswering"),
|
||||
("DistilBertConfig", "DistilBertForQuestionAnswering"),
|
||||
("AlbertConfig", "AlbertForQuestionAnswering"),
|
||||
("CamembertConfig", "CamembertForQuestionAnswering"),
|
||||
("BartConfig", "BartForQuestionAnswering"),
|
||||
("MBartConfig", "MBartForQuestionAnswering"),
|
||||
("LongformerConfig", "LongformerForQuestionAnswering"),
|
||||
("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
|
||||
("RobertaConfig", "RobertaForQuestionAnswering"),
|
||||
("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
|
||||
("BertConfig", "BertForQuestionAnswering"),
|
||||
("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
|
||||
("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
|
||||
("MegatronBertConfig", "MegatronBertForQuestionAnswering"),
|
||||
("MobileBertConfig", "MobileBertForQuestionAnswering"),
|
||||
("XLMConfig", "XLMForQuestionAnsweringSimple"),
|
||||
("ElectraConfig", "ElectraForQuestionAnswering"),
|
||||
("ReformerConfig", "ReformerForQuestionAnswering"),
|
||||
("FunnelConfig", "FunnelForQuestionAnswering"),
|
||||
("LxmertConfig", "LxmertForQuestionAnswering"),
|
||||
("MPNetConfig", "MPNetForQuestionAnswering"),
|
||||
("DebertaConfig", "DebertaForQuestionAnswering"),
|
||||
("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
|
||||
("IBertConfig", "IBertForQuestionAnswering"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("RemBertConfig", "RemBertForCausalLM"),
|
||||
("RoFormerConfig", "RoFormerForCausalLM"),
|
||||
("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"),
|
||||
("GPTNeoConfig", "GPTNeoForCausalLM"),
|
||||
("BigBirdConfig", "BigBirdForCausalLM"),
|
||||
("CamembertConfig", "CamembertForCausalLM"),
|
||||
("XLMRobertaConfig", "XLMRobertaForCausalLM"),
|
||||
("RobertaConfig", "RobertaForCausalLM"),
|
||||
("BertConfig", "BertLMHeadModel"),
|
||||
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
|
||||
("GPT2Config", "GPT2LMHeadModel"),
|
||||
("TransfoXLConfig", "TransfoXLLMHeadModel"),
|
||||
("XLNetConfig", "XLNetLMHeadModel"),
|
||||
("XLMConfig", "XLMWithLMHeadModel"),
|
||||
("CTRLConfig", "CTRLLMHeadModel"),
|
||||
("ReformerConfig", "ReformerModelWithLMHead"),
|
||||
("BertGenerationConfig", "BertGenerationDecoder"),
|
||||
("XLMProphetNetConfig", "XLMProphetNetForCausalLM"),
|
||||
("ProphetNetConfig", "ProphetNetForCausalLM"),
|
||||
("BartConfig", "BartForCausalLM"),
|
||||
("MBartConfig", "MBartForCausalLM"),
|
||||
("PegasusConfig", "PegasusForCausalLM"),
|
||||
("MarianConfig", "MarianForCausalLM"),
|
||||
("BlenderbotConfig", "BlenderbotForCausalLM"),
|
||||
("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"),
|
||||
("MegatronBertConfig", "MegatronBertForCausalLM"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("ViTConfig", "ViTForImageClassification"),
|
||||
("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"),
|
||||
("BeitConfig", "BeitForImageClassification"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        ("RemBertConfig", "RemBertForMaskedLM"),
        ("RoFormerConfig", "RoFormerForMaskedLM"),
        ("BigBirdConfig", "BigBirdForMaskedLM"),
        ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
        ("ConvBertConfig", "ConvBertForMaskedLM"),
        ("LayoutLMConfig", "LayoutLMForMaskedLM"),
        ("DistilBertConfig", "DistilBertForMaskedLM"),
        ("AlbertConfig", "AlbertForMaskedLM"),
        ("BartConfig", "BartForConditionalGeneration"),
        ("MBartConfig", "MBartForConditionalGeneration"),
        ("CamembertConfig", "CamembertForMaskedLM"),
        ("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
        ("LongformerConfig", "LongformerForMaskedLM"),
        ("RobertaConfig", "RobertaForMaskedLM"),
        ("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
        ("BertConfig", "BertForMaskedLM"),
        ("MegatronBertConfig", "MegatronBertForMaskedLM"),
        ("MobileBertConfig", "MobileBertForMaskedLM"),
        ("FlaubertConfig", "FlaubertWithLMHeadModel"),
        ("XLMConfig", "XLMWithLMHeadModel"),
        ("ElectraConfig", "ElectraForMaskedLM"),
        ("ReformerConfig", "ReformerForMaskedLM"),
        ("FunnelConfig", "FunnelForMaskedLM"),
        ("MPNetConfig", "MPNetForMaskedLM"),
        ("TapasConfig", "TapasForMaskedLM"),
        ("DebertaConfig", "DebertaForMaskedLM"),
        ("DebertaV2Config", "DebertaV2ForMaskedLM"),
        ("IBertConfig", "IBertForMaskedLM"),
    ]
)
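As with the other tables, nothing changes at the call site; a masked-LM lookup still behaves as before (checkpoint name is an example):

from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")  # BertForMaskedLM

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
logits = model(**inputs).logits
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
print(tokenizer.decode([logits[0, mask_pos].argmax().item()]))  # expected: "capital"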
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        ("RemBertConfig", "RemBertForMultipleChoice"),
        ("CanineConfig", "CanineForMultipleChoice"),
        ("RoFormerConfig", "RoFormerForMultipleChoice"),
        ("BigBirdConfig", "BigBirdForMultipleChoice"),
        ("ConvBertConfig", "ConvBertForMultipleChoice"),
        ("CamembertConfig", "CamembertForMultipleChoice"),
        ("ElectraConfig", "ElectraForMultipleChoice"),
        ("XLMRobertaConfig", "XLMRobertaForMultipleChoice"),
        ("LongformerConfig", "LongformerForMultipleChoice"),
        ("RobertaConfig", "RobertaForMultipleChoice"),
        ("SqueezeBertConfig", "SqueezeBertForMultipleChoice"),
        ("BertConfig", "BertForMultipleChoice"),
        ("DistilBertConfig", "DistilBertForMultipleChoice"),
        ("MegatronBertConfig", "MegatronBertForMultipleChoice"),
        ("MobileBertConfig", "MobileBertForMultipleChoice"),
        ("XLNetConfig", "XLNetForMultipleChoice"),
        ("AlbertConfig", "AlbertForMultipleChoice"),
        ("XLMConfig", "XLMForMultipleChoice"),
        ("FlaubertConfig", "FlaubertForMultipleChoice"),
        ("FunnelConfig", "FunnelForMultipleChoice"),
        ("MPNetConfig", "MPNetForMultipleChoice"),
        ("IBertConfig", "IBertForMultipleChoice"),
    ]
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("BertConfig", "BertForNextSentencePrediction"),
        ("MegatronBertConfig", "MegatronBertForNextSentencePrediction"),
        ("MobileBertConfig", "MobileBertForNextSentencePrediction"),
    ]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        ("DetrConfig", "DetrForObjectDetection"),
    ]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
        ("M2M100Config", "M2M100ForConditionalGeneration"),
        ("LEDConfig", "LEDForConditionalGeneration"),
        ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
        ("MT5Config", "MT5ForConditionalGeneration"),
        ("T5Config", "T5ForConditionalGeneration"),
        ("PegasusConfig", "PegasusForConditionalGeneration"),
        ("MarianConfig", "MarianMTModel"),
        ("MBartConfig", "MBartForConditionalGeneration"),
        ("BlenderbotConfig", "BlenderbotForConditionalGeneration"),
        ("BartConfig", "BartForConditionalGeneration"),
        ("FSMTConfig", "FSMTForConditionalGeneration"),
        ("EncoderDecoderConfig", "EncoderDecoderModel"),
        ("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"),
        ("ProphetNetConfig", "ProphetNetForConditionalGeneration"),
    ]
)
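A seq2seq entry resolves the same way, for instance ("T5Config", "T5ForConditionalGeneration") through AutoModelForSeq2SeqLM (checkpoint name is an example):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # T5ForConditionalGeneration

batch = tokenizer("translate English to German: Hello!", return_tensors="pt")
print(tokenizer.decode(model.generate(**batch)[0], skip_special_tokens=True))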
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("RemBertConfig", "RemBertForSequenceClassification"),
        ("CanineConfig", "CanineForSequenceClassification"),
        ("RoFormerConfig", "RoFormerForSequenceClassification"),
        ("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"),
        ("BigBirdConfig", "BigBirdForSequenceClassification"),
        ("ConvBertConfig", "ConvBertForSequenceClassification"),
        ("LEDConfig", "LEDForSequenceClassification"),
        ("DistilBertConfig", "DistilBertForSequenceClassification"),
        ("AlbertConfig", "AlbertForSequenceClassification"),
        ("CamembertConfig", "CamembertForSequenceClassification"),
        ("XLMRobertaConfig", "XLMRobertaForSequenceClassification"),
        ("MBartConfig", "MBartForSequenceClassification"),
        ("BartConfig", "BartForSequenceClassification"),
        ("LongformerConfig", "LongformerForSequenceClassification"),
        ("RobertaConfig", "RobertaForSequenceClassification"),
        ("SqueezeBertConfig", "SqueezeBertForSequenceClassification"),
        ("LayoutLMConfig", "LayoutLMForSequenceClassification"),
        ("BertConfig", "BertForSequenceClassification"),
        ("XLNetConfig", "XLNetForSequenceClassification"),
        ("MegatronBertConfig", "MegatronBertForSequenceClassification"),
        ("MobileBertConfig", "MobileBertForSequenceClassification"),
        ("FlaubertConfig", "FlaubertForSequenceClassification"),
        ("XLMConfig", "XLMForSequenceClassification"),
        ("ElectraConfig", "ElectraForSequenceClassification"),
        ("FunnelConfig", "FunnelForSequenceClassification"),
        ("DebertaConfig", "DebertaForSequenceClassification"),
        ("DebertaV2Config", "DebertaV2ForSequenceClassification"),
        ("GPT2Config", "GPT2ForSequenceClassification"),
        ("GPTNeoConfig", "GPTNeoForSequenceClassification"),
        ("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"),
        ("ReformerConfig", "ReformerForSequenceClassification"),
        ("CTRLConfig", "CTRLForSequenceClassification"),
        ("TransfoXLConfig", "TransfoXLForSequenceClassification"),
        ("MPNetConfig", "MPNetForSequenceClassification"),
        ("TapasConfig", "TapasForSequenceClassification"),
        ("IBertConfig", "IBertForSequenceClassification"),
    ]
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("TapasConfig", "TapasForQuestionAnswering"),
    ]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("RemBertConfig", "RemBertForTokenClassification"),
        ("CanineConfig", "CanineForTokenClassification"),
        ("RoFormerConfig", "RoFormerForTokenClassification"),
        ("BigBirdConfig", "BigBirdForTokenClassification"),
        ("ConvBertConfig", "ConvBertForTokenClassification"),
        ("LayoutLMConfig", "LayoutLMForTokenClassification"),
        ("DistilBertConfig", "DistilBertForTokenClassification"),
        ("CamembertConfig", "CamembertForTokenClassification"),
        ("FlaubertConfig", "FlaubertForTokenClassification"),
        ("XLMConfig", "XLMForTokenClassification"),
        ("XLMRobertaConfig", "XLMRobertaForTokenClassification"),
        ("LongformerConfig", "LongformerForTokenClassification"),
        ("RobertaConfig", "RobertaForTokenClassification"),
        ("SqueezeBertConfig", "SqueezeBertForTokenClassification"),
        ("BertConfig", "BertForTokenClassification"),
        ("MegatronBertConfig", "MegatronBertForTokenClassification"),
        ("MobileBertConfig", "MobileBertForTokenClassification"),
        ("XLNetConfig", "XLNetForTokenClassification"),
        ("AlbertConfig", "AlbertForTokenClassification"),
        ("ElectraConfig", "ElectraForTokenClassification"),
        ("FunnelConfig", "FunnelForTokenClassification"),
        ("MPNetConfig", "MPNetForTokenClassification"),
        ("DebertaConfig", "DebertaForTokenClassification"),
        ("DebertaV2Config", "DebertaV2ForTokenClassification"),
        ("IBertConfig", "IBertForTokenClassification"),
    ]
)
MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("BeitConfig", "BeitModel"),
        ("RemBertConfig", "RemBertModel"),
        ("VisualBertConfig", "VisualBertModel"),
        ("CanineConfig", "CanineModel"),
        ("RoFormerConfig", "RoFormerModel"),
        ("CLIPConfig", "CLIPModel"),
        ("BigBirdPegasusConfig", "BigBirdPegasusModel"),
        ("DeiTConfig", "DeiTModel"),
        ("LukeConfig", "LukeModel"),
        ("DetrConfig", "DetrModel"),
        ("GPTNeoConfig", "GPTNeoModel"),
        ("BigBirdConfig", "BigBirdModel"),
        ("Speech2TextConfig", "Speech2TextModel"),
        ("ViTConfig", "ViTModel"),
        ("Wav2Vec2Config", "Wav2Vec2Model"),
        ("HubertConfig", "HubertModel"),
        ("M2M100Config", "M2M100Model"),
        ("ConvBertConfig", "ConvBertModel"),
        ("LEDConfig", "LEDModel"),
        ("BlenderbotSmallConfig", "BlenderbotSmallModel"),
        ("RetriBertConfig", "RetriBertModel"),
        ("MT5Config", "MT5Model"),
        ("T5Config", "T5Model"),
        ("PegasusConfig", "PegasusModel"),
        ("MarianConfig", "MarianModel"),
        ("MBartConfig", "MBartModel"),
        ("BlenderbotConfig", "BlenderbotModel"),
        ("DistilBertConfig", "DistilBertModel"),
        ("AlbertConfig", "AlbertModel"),
        ("CamembertConfig", "CamembertModel"),
        ("XLMRobertaConfig", "XLMRobertaModel"),
        ("BartConfig", "BartModel"),
        ("LongformerConfig", "LongformerModel"),
        ("RobertaConfig", "RobertaModel"),
        ("LayoutLMConfig", "LayoutLMModel"),
        ("SqueezeBertConfig", "SqueezeBertModel"),
        ("BertConfig", "BertModel"),
        ("OpenAIGPTConfig", "OpenAIGPTModel"),
        ("GPT2Config", "GPT2Model"),
        ("MegatronBertConfig", "MegatronBertModel"),
        ("MobileBertConfig", "MobileBertModel"),
        ("TransfoXLConfig", "TransfoXLModel"),
        ("XLNetConfig", "XLNetModel"),
        ("FlaubertConfig", "FlaubertModel"),
        ("FSMTConfig", "FSMTModel"),
        ("XLMConfig", "XLMModel"),
        ("CTRLConfig", "CTRLModel"),
        ("ElectraConfig", "ElectraModel"),
        ("ReformerConfig", "ReformerModel"),
        ("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"),
        ("LxmertConfig", "LxmertModel"),
        ("BertGenerationConfig", "BertGenerationEncoder"),
        ("DebertaConfig", "DebertaModel"),
        ("DebertaV2Config", "DebertaV2Model"),
        ("DPRConfig", "DPRQuestionEncoder"),
        ("XLMProphetNetConfig", "XLMProphetNetModel"),
        ("ProphetNetConfig", "ProphetNetModel"),
        ("MPNetConfig", "MPNetModel"),
        ("TapasConfig", "TapasModel"),
        ("IBertConfig", "IBertModel"),
    ]
)
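MODEL_MAPPING_NAMES backs the plain AutoModel class. A quick check that the lookup still works end to end without loading any weights:

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModel.from_config(config)  # randomly initialized; no weights downloaded
print(type(model).__name__)  # -> BertModel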
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        ("RemBertConfig", "RemBertForMaskedLM"),
        ("RoFormerConfig", "RoFormerForMaskedLM"),
        ("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
        ("GPTNeoConfig", "GPTNeoForCausalLM"),
        ("BigBirdConfig", "BigBirdForMaskedLM"),
        ("Speech2TextConfig", "Speech2TextForConditionalGeneration"),
        ("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
        ("M2M100Config", "M2M100ForConditionalGeneration"),
        ("ConvBertConfig", "ConvBertForMaskedLM"),
        ("LEDConfig", "LEDForConditionalGeneration"),
        ("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
        ("LayoutLMConfig", "LayoutLMForMaskedLM"),
        ("T5Config", "T5ForConditionalGeneration"),
        ("DistilBertConfig", "DistilBertForMaskedLM"),
        ("AlbertConfig", "AlbertForMaskedLM"),
        ("CamembertConfig", "CamembertForMaskedLM"),
        ("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
        ("MarianConfig", "MarianMTModel"),
        ("FSMTConfig", "FSMTForConditionalGeneration"),
        ("BartConfig", "BartForConditionalGeneration"),
        ("LongformerConfig", "LongformerForMaskedLM"),
        ("RobertaConfig", "RobertaForMaskedLM"),
        ("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
        ("BertConfig", "BertForMaskedLM"),
        ("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
        ("GPT2Config", "GPT2LMHeadModel"),
        ("MegatronBertConfig", "MegatronBertForCausalLM"),
        ("MobileBertConfig", "MobileBertForMaskedLM"),
        ("TransfoXLConfig", "TransfoXLLMHeadModel"),
        ("XLNetConfig", "XLNetLMHeadModel"),
        ("FlaubertConfig", "FlaubertWithLMHeadModel"),
        ("XLMConfig", "XLMWithLMHeadModel"),
        ("CTRLConfig", "CTRLLMHeadModel"),
        ("ElectraConfig", "ElectraForMaskedLM"),
        ("EncoderDecoderConfig", "EncoderDecoderModel"),
        ("ReformerConfig", "ReformerModelWithLMHead"),
        ("FunnelConfig", "FunnelForMaskedLM"),
        ("MPNetConfig", "MPNetForMaskedLM"),
        ("TapasConfig", "TapasForMaskedLM"),
        ("DebertaConfig", "DebertaForMaskedLM"),
        ("DebertaV2Config", "DebertaV2ForMaskedLM"),
        ("IBertConfig", "IBertForMaskedLM"),
    ]
)
@@ -17,7 +17,7 @@
{% if cookiecutter.is_encoder_decoder_model == "False" %}

import math
from typing import Any, Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
@@ -1484,7 +1484,7 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPast,
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFSeq2SeqLMOutput,
    TFSeq2SeqModelOutput,
)
@@ -2162,7 +2162,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
        )

        # encoder layers
        for encoder_layer in self.layers:
        for idx, encoder_layer in enumerate(self.layers):

            if inputs["output_hidden_states"]:
                encoder_states = encoder_states + (hidden_states,)

@@ -172,17 +172,12 @@
# To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here"
# Replace with:
    ("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Config"),
# End.

# Below: "# Add archive maps here"
# Replace with:
    {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End.

# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with:
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP"),
# End.

# Below: "# Add full (and cased) model names here"
@@ -193,75 +188,47 @@ from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config


# To replace in: "src/transformers/models/auto/modeling_auto.py" if generating PyTorch
# Below: "from .configuration_auto import ("
# Replace with:
    {{cookiecutter.camelcase_modelname}}Config,
# End.

# Below: "# Add modeling imports here"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
    {{cookiecutter.camelcase_modelname}}ForMaskedLM,
    {{cookiecutter.camelcase_modelname}}ForCausalLM,
    {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
    {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
    {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
    {{cookiecutter.camelcase_modelname}}ForTokenClassification,
    {{cookiecutter.camelcase_modelname}}Model,
)
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
    {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
    {{cookiecutter.camelcase_modelname}}ForCausalLM,
    {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
    {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
    {{cookiecutter.camelcase_modelname}}Model,
)
{% endif -%}
# End.

# Below: "# Base model mapping"
# Replace with:
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Model),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Model"),
# End.

# Below: "# Model with LM heads mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else %}
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.

# Below: "# Model for Causal LM mapping"
# Replace with:
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForCausalLM"),
# End.

# Below: "# Model for Masked LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else -%}
{% endif -%}
# End.

# Below: "# Model for Sequence Classification mapping"
# Replace with:
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForSequenceClassification),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
# End.

# Below: "# Model for Question Answering mapping"
# Replace with:
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
# End.

# Below: "# Model for Token Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForTokenClassification),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
{% else -%}
{% endif -%}
# End.
@@ -269,7 +236,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Model for Multiple Choice mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMultipleChoice),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
{% else -%}
{% endif -%}
# End.
@@ -278,54 +245,29 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else %}
    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.

# To replace in: "src/transformers/models/auto/modeling_tf_auto.py" if generating TensorFlow
# Below: "from .configuration_auto import ("
# Replace with:
    {{cookiecutter.camelcase_modelname}}Config,
# End.

# Below: "# Add modeling imports here"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
    TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
    TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
    TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
    TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
    TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
    TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
    TF{{cookiecutter.camelcase_modelname}}Model,
)
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
    TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
    TF{{cookiecutter.camelcase_modelname}}Model,
)
{% endif -%}
# End.

# Below: "# Base model mapping"
# Replace with:
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}Model),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}Model"),
# End.

# Below: "# Model with LM heads mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else %}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.

# Below: "# Model for Causal LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForCausalLM),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForCausalLM"),
{% else -%}
{% endif -%}
# End.
@@ -333,7 +275,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Model for Masked LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else -%}
{% endif -%}
# End.
@@ -341,7 +283,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Model for Sequence Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
{% else -%}
{% endif -%}
# End.
@@ -349,7 +291,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Model for Question Answering mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
{% else -%}
{% endif -%}
# End.
@@ -357,7 +299,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Model for Token Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForTokenClassification),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
{% else -%}
{% endif -%}
# End.
@@ -365,7 +307,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Model for Multiple Choice mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
{% else -%}
{% endif -%}
# End.
@@ -374,7 +316,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else %}
    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.

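The template above relies on a simple marker protocol: each block names an anchor line in the target file (`# Below: ...`) followed by the lines to insert (`# Replace with:` ... `# End.`). Roughly, the add-new-model tooling does something like the following sketch (a simplification under my own helper name, not the actual CLI code):

def apply_template_block(target_lines, anchor, replacement_lines):
    # Insert replacement_lines right below the first line containing `anchor`.
    patched = []
    for line in target_lines:
        patched.append(line)
        if anchor in line:
            patched.extend(replacement_lines)
    return patched


target = ["MODEL_MAPPING_NAMES = OrderedDict(", "    [", "        # Base model mapping", "    ]", ")"]
patched = apply_template_block(target, "# Base model mapping", ['        ("my_model", "MyModelModel"),'])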
@@ -23,7 +23,8 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin


if is_torch_available():
    from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
    from transformers.models.mbart import MBartForConditionalGeneration
    from transformers.models.mbart50 import MBart50TokenizerFast


class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
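With the split, MBart50TokenizerFast now lives in its own mbart50 module while the model class stays in mbart; both continue to work together. A usage sketch with a public mBART-50 checkpoint (the checkpoint name is an example):

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(name, src_lang="en_XX")
model = MBartForConditionalGeneration.from_pretrained(name)

batch = tokenizer("Hello world", return_tensors="pt")
generated = model.generate(**batch, forced_bos_token_id=tokenizer.lang_code_to_id["de_DE"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True))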
@@ -306,17 +306,17 @@ def get_all_auto_configured_models():
    result = set()  # To avoid duplicates we concatenate all model classes in a set.
    if is_torch_available():
        for attr_name in dir(transformers.models.auto.modeling_auto):
            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
    if is_tf_available():
        for attr_name in dir(transformers.models.auto.modeling_tf_auto):
            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
    if is_flax_available():
        for attr_name in dir(transformers.models.auto.modeling_flax_auto):
            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
    return [cls.__name__ for cls in result]
    return [cls for cls in result]


def ignore_unautoclassed(model_name):
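The MAPPING_NAMES tables now contain strings (sometimes tuples of strings), so the final list comprehension no longer needs cls.__name__. get_values is the helper that flattens the mapping values; a sketch of the behavior it needs under that assumption:

def get_values(mapping):
    # Flatten mapping values, expanding tuple entries such as DeiT's
    # ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher").
    values = []
    for value in mapping.values():
        if isinstance(value, (list, tuple)):
            values.extend(value)
        else:
            values.append(value)
    return values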
@@ -87,12 +87,13 @@ def get_model_table_from_auto_modules():
    transformers = spec.loader.load_module()

    # Dictionary model names to config.
    config_maping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
    model_name_to_config = {
        name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
    }
    model_name_to_prefix = {
        name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
        name: config_maping_names[code]
        for code, name in transformers.MODEL_NAMES_MAPPING.items()
        if code in config_maping_names
    }
    model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}

    # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
    slow_tokenizers = collections.defaultdict(bool)
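The rewritten snippet derives the README-table prefix from the config *name string* rather than from an imported config class. A tiny worked example with stand-in values (the variable name config_maping_names keeps the spelling used in the source):

config_maping_names = {"bert": "BertConfig"}  # excerpt of CONFIG_MAPPING_NAMES
MODEL_NAMES_MAPPING = {"bert": "BERT"}        # model type -> display name

model_name_to_config = {
    name: config_maping_names[code]
    for code, name in MODEL_NAMES_MAPPING.items()
    if code in config_maping_names
}
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
assert model_name_to_prefix == {"BERT": "Bert"}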
@@ -1,106 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# this script remaps classes to class strings so that it's quick to load such maps and not require
# loading all possible modeling files
#
# it can be extended to auto-generate other dicts that are needed at runtime


import os
import sys
from os.path import abspath, dirname, join


git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)

src = "src/transformers/models/auto/modeling_auto.py"
dst = "src/transformers/utils/modeling_auto_mapping.py"


if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
    # speed things up by only running this script if the src is newer than dst
    sys.exit(0)

# only load if needed
from transformers.models.auto.modeling_auto import (  # noqa
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    MODEL_FOR_OBJECT_DETECTION_MAPPING,
    MODEL_FOR_PRETRAINING_MAPPING,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    MODEL_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
)


# Those constants don't have a name attribute, so we need to define it manually
mappings = {
    "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING,
    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    "MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING,
    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
    "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
    "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    "MODEL_MAPPING": MODEL_MAPPING,
    "MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING,
}


def get_name(value):
    if isinstance(value, tuple):
        return tuple(get_name(o) for o in value)
    return value.__name__


content = [
    "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
    "# 1. modify: models/auto/modeling_auto.py",
    "# 2. run: python utils/class_mapping_update.py",
    "from collections import OrderedDict",
    "",
]

for name, mapping in mappings.items():
    entries = "\n".join([f'        ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()])

    content += [
        "",
        f"{name}_NAMES = OrderedDict(",
        "    [",
        entries,
        "    ]",
        ")",
        "",
    ]

print(f"Updating {dst}")
with open(dst, "w", encoding="utf-8", newline="\n") as f:
    f.write("\n".join(content))