[WIP] Disentangle auto modules from other modeling files (#13023)

* Initial work

* All auto models

* All tf auto models

* All flax auto models

* Tokenizers

* Add feature extractors

* Fix typos

* Fix other typo

* Use the right config

* Remove old mapping names and update logic in AutoTokenizer

* Update check_table

* Fix copies and check_repo script

* Fix last test

* Add back name

* clean up

* Update template

* Update template

* Forgot a )

* Use alternative to fixup

* Fix TF model template

* Address review comments

* Address review comments

* Style
Sylvain Gugger 2021-08-06 13:12:30 +02:00 committed by GitHub
parent 2e4082364e
commit 9870093f7b
26 changed files with 1338 additions and 2405 deletions

View File

@@ -59,7 +59,7 @@ jobs:
- name: Run style changes
run: |
git fetch origin master:master
make fixup
make style && make quality
- name: Failure short reports
if: ${{ always() }}

View File

@@ -30,7 +30,6 @@ deps_table_check_updated:
# autogenerating code
autogenerate_code: deps_table_update
python utils/class_mapping_update.py
# Check that source code meets quality standards

View File

@@ -213,6 +213,7 @@ _import_structure = {
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.marian": ["MarianConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.mmbt": ["MMBTConfig"],
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
@@ -315,7 +316,7 @@ if is_sentencepiece_available():
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart"].append("MBart50Tokenizer")
_import_structure["models.mbart50"].append("MBart50Tokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
@@ -358,7 +359,7 @@ if is_tokenizers_available():
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.mbart"].append("MBartTokenizerFast")
_import_structure["models.mbart"].append("MBart50TokenizerFast")
_import_structure["models.mbart50"].append("MBart50TokenizerFast")
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
_import_structure["models.mt5"].append("MT5TokenizerFast")
@@ -2021,7 +2022,8 @@ if TYPE_CHECKING:
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
from .models.mbart import MBartTokenizerFast
from .models.mbart50 import MBart50TokenizerFast
from .models.mobilebert import MobileBertTokenizerFast
from .models.mpnet import MPNetTokenizerFast
from .models.mt5 import MT5TokenizerFast

View File

@@ -41,9 +41,7 @@ from .file_utils import (
is_tokenizers_available,
is_torch_available,
)
from .training_args import ParallelMode
from .utils import logging
from .utils.modeling_auto_mapping import (
from .models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
@@ -54,6 +52,8 @@ from .utils.modeling_auto_mapping import (
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode
from .utils import logging
TASK_MAPPING = {

View File

@@ -37,6 +37,7 @@ from . import (
cpm,
ctrl,
deberta,
deberta_v2,
deit,
detr,
dialogpt,
@@ -50,6 +51,8 @@ from . import (
gpt2,
gpt_neo,
herbert,
hubert,
ibert,
layoutlm,
led,
longformer,
@@ -58,6 +61,7 @@ from . import (
m2m_100,
marian,
mbart,
mbart50,
megatron_bert,
mmbt,
mobilebert,
@@ -82,6 +86,7 @@ from . import (
vit,
wav2vec2,
xlm,
xlm_prophetnet,
xlm_roberta,
xlnet,
)

View File

@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
import importlib
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
logger = logging.get_logger(__name__)
@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
cls.from_config = classmethod(from_config)
if name.startswith("TF"):
@@ -431,7 +433,7 @@
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
cls.from_pretrained = classmethod(from_pretrained)
return cls
@@ -445,3 +447,79 @@ def get_values(model_mapping):
result.append(model)
return result
def getattribute_from_module(module, attr):
if attr is None:
return None
if isinstance(attr, tuple):
return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr):
return getattr(module, attr)
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
# object at the top level.
transformers_module = importlib.import_module("transformers")
return getattribute_from_module(transformers_module, attr)
class _LazyAutoMapping(OrderedDict):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
Args:
- config_mapping: The map model type to config class
- model_mapping: The map model type to model (or tokenizer) class
"""
def __init__(self, config_mapping, model_mapping):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._modules = {}
def __getitem__(self, key):
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattribute_from_module(self._modules[module_name], attr)
def keys(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
if key in self._model_mapping.keys()
]
def values(self):
return [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
if key in self._config_mapping.keys()
]
def items(self):
return [
(
self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key]),
)
for key in self._model_mapping.keys()
if key in self._config_mapping.keys()
]
def __iter__(self):
return iter(self.keys())
def __contains__(self, item):
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
return False
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping
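The net effect of _LazyAutoMapping is that indexing by a config class imports only the single model module involved. A minimal sketch of that behavior, assuming a transformers install that includes this commit (MODEL_MAPPING is the PyTorch instance built from these name mappings; the lookup below is illustrative):

    from transformers import BertConfig
    from transformers.models.auto.modeling_auto import MODEL_MAPPING

    # The BERT modeling file is only imported at the moment of this lookup,
    # via _load_attr_from_module -> importlib.import_module.
    model_class = MODEL_MAPPING[BertConfig]
    print(model_class.__name__)  # "BertModel"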

View File

@@ -13,215 +13,140 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Config class. """
import importlib
import re
import warnings
from collections import OrderedDict
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from ..beit.configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig
from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from ..bert_generation.configuration_bert_generation import BertGenerationConfig
from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
from ..bigbird_pegasus.configuration_bigbird_pegasus import (
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
BigBirdPegasusConfig,
)
from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from ..blenderbot_small.configuration_blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
BlenderbotSmallConfig,
)
from ..camembert.configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from ..canine.configuration_canine import CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, CanineConfig
from ..clip.configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig
from ..convbert.configuration_convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig
from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig
from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from ..encoder_decoder.configuration_encoder_decoder import EncoderDecoderConfig
from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
from ..luke.configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..megatron_bert.configuration_megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..mpnet.configuration_mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig
from ..mt5.configuration_mt5 import MT5Config
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
from ..rag.configuration_rag import RagConfig
from ..reformer.configuration_reformer import ReformerConfig
from ..rembert.configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from ..roformer.configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
from ..speech_to_text.configuration_speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
)
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from ..visual_bert.configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLMProphetNetConfig,
)
from ..xlm_roberta.configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from ..xlnet.configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
# Add archive maps here
BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP,
DETR_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP,
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP,
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
CONFIG_MAPPING = OrderedDict(
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("beit", BeitConfig),
("rembert", RemBertConfig),
("visual_bert", VisualBertConfig),
("canine", CanineConfig),
("roformer", RoFormerConfig),
("clip", CLIPConfig),
("bigbird_pegasus", BigBirdPegasusConfig),
("deit", DeiTConfig),
("luke", LukeConfig),
("detr", DetrConfig),
("gpt_neo", GPTNeoConfig),
("big_bird", BigBirdConfig),
("speech_to_text", Speech2TextConfig),
("vit", ViTConfig),
("wav2vec2", Wav2Vec2Config),
("m2m_100", M2M100Config),
("convbert", ConvBertConfig),
("led", LEDConfig),
("blenderbot-small", BlenderbotSmallConfig),
("retribert", RetriBertConfig),
("ibert", IBertConfig),
("mt5", MT5Config),
("t5", T5Config),
("mobilebert", MobileBertConfig),
("distilbert", DistilBertConfig),
("albert", AlbertConfig),
("bert-generation", BertGenerationConfig),
("camembert", CamembertConfig),
("xlm-roberta", XLMRobertaConfig),
("pegasus", PegasusConfig),
("marian", MarianConfig),
("mbart", MBartConfig),
("megatron-bert", MegatronBertConfig),
("mpnet", MPNetConfig),
("bart", BartConfig),
("blenderbot", BlenderbotConfig),
("reformer", ReformerConfig),
("longformer", LongformerConfig),
("roberta", RobertaConfig),
("deberta-v2", DebertaV2Config),
("deberta", DebertaConfig),
("flaubert", FlaubertConfig),
("fsmt", FSMTConfig),
("squeezebert", SqueezeBertConfig),
("hubert", HubertConfig),
("bert", BertConfig),
("openai-gpt", OpenAIGPTConfig),
("gpt2", GPT2Config),
("transfo-xl", TransfoXLConfig),
("xlnet", XLNetConfig),
("xlm-prophetnet", XLMProphetNetConfig),
("prophetnet", ProphetNetConfig),
("xlm", XLMConfig),
("ctrl", CTRLConfig),
("electra", ElectraConfig),
("encoder-decoder", EncoderDecoderConfig),
("funnel", FunnelConfig),
("lxmert", LxmertConfig),
("dpr", DPRConfig),
("layoutlm", LayoutLMConfig),
("rag", RagConfig),
("tapas", TapasConfig),
("beit", "BeitConfig"),
("rembert", "RemBertConfig"),
("visual_bert", "VisualBertConfig"),
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
("detr", "DetrConfig"),
("gpt_neo", "GPTNeoConfig"),
("big_bird", "BigBirdConfig"),
("speech_to_text", "Speech2TextConfig"),
("vit", "ViTConfig"),
("wav2vec2", "Wav2Vec2Config"),
("m2m_100", "M2M100Config"),
("convbert", "ConvBertConfig"),
("led", "LEDConfig"),
("blenderbot-small", "BlenderbotSmallConfig"),
("retribert", "RetriBertConfig"),
("ibert", "IBertConfig"),
("mt5", "MT5Config"),
("t5", "T5Config"),
("mobilebert", "MobileBertConfig"),
("distilbert", "DistilBertConfig"),
("albert", "AlbertConfig"),
("bert-generation", "BertGenerationConfig"),
("camembert", "CamembertConfig"),
("xlm-roberta", "XLMRobertaConfig"),
("pegasus", "PegasusConfig"),
("marian", "MarianConfig"),
("mbart", "MBartConfig"),
("megatron-bert", "MegatronBertConfig"),
("mpnet", "MPNetConfig"),
("bart", "BartConfig"),
("blenderbot", "BlenderbotConfig"),
("reformer", "ReformerConfig"),
("longformer", "LongformerConfig"),
("roberta", "RobertaConfig"),
("deberta-v2", "DebertaV2Config"),
("deberta", "DebertaConfig"),
("flaubert", "FlaubertConfig"),
("fsmt", "FSMTConfig"),
("squeezebert", "SqueezeBertConfig"),
("hubert", "HubertConfig"),
("bert", "BertConfig"),
("openai-gpt", "OpenAIGPTConfig"),
("gpt2", "GPT2Config"),
("transfo-xl", "TransfoXLConfig"),
("xlnet", "XLNetConfig"),
("xlm-prophetnet", "XLMProphetNetConfig"),
("prophetnet", "ProphetNetConfig"),
("xlm", "XLMConfig"),
("ctrl", "CTRLConfig"),
("electra", "ElectraConfig"),
("encoder-decoder", "EncoderDecoderConfig"),
("funnel", "FunnelConfig"),
("lxmert", "LxmertConfig"),
("dpr", "DPRConfig"),
("layoutlm", "LayoutLMConfig"),
("rag", "RagConfig"),
("tapas", "TapasConfig"),
]
)
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt_neo", "GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("big_bird", "BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("speech_to_text", "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot-small", "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bert", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bart", "BART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("blenderbot", "BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mbart", "MBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gpt2", "GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlnet", "XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm", "XLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roberta", "ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("albert", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("camembert", "CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-roberta", "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("flaubert", "FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("fsmt", "FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("retribert", "RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("funnel", "FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deberta-v2", "DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xlm-prophetnet", "XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("prophetnet", "PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mpnet", "MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
]
)
@@ -290,14 +215,136 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mpnet", "MPNet"),
("tapas", "TAPAS"),
("hubert", "Hubert"),
("barthez", "BARThez"),
("phobert", "PhoBERT"),
("cpm", "CPM"),
("bertweet", "Bertweet"),
("bert-japanese", "BertJapanese"),
("byt5", "ByT5"),
("mbart50", "mBART-50"),
]
)
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict([("openai-gpt", "openai")])
def _get_class_name(model_class):
def model_type_to_module_name(key):
"""Converts a config key to the corresponding module."""
# Special treatment
if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[key]
return key.replace("-", "_")
def config_class_to_model_type(config):
"""Converts a config class name to the corresponding model type"""
for key, cls in CONFIG_MAPPING_NAMES.items():
if cls == config:
return key
return None
class _LazyConfigMapping(OrderedDict):
"""
A dictionary that lazily loads its values when they are requested.
"""
def __init__(self, mapping):
self._mapping = mapping
self._modules = {}
def __getitem__(self, key):
if key not in self._mapping:
raise KeyError(key)
value = self._mapping[key]
module_name = model_type_to_module_name(key)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(self._modules[module_name], value)
def keys(self):
return self._mapping.keys()
def values(self):
return [self[k] for k in self._mapping.keys()]
def items(self):
return [(k, self[k]) for k in self._mapping.keys()]
def __iter__(self):
return iter(self._mapping.keys())
def __contains__(self, item):
return item in self._mapping
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
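A short sketch of how the helpers above compose (illustrative, not part of the diff): model_type_to_module_name handles the dash-to-underscore and special-case renames, and CONFIG_MAPPING defers the actual import until a key is read.

    from transformers.models.auto.configuration_auto import (
        CONFIG_MAPPING,
        model_type_to_module_name,
    )

    print(model_type_to_module_name("openai-gpt"))        # "openai" (special-cased)
    print(model_type_to_module_name("blenderbot-small"))  # "blenderbot_small"

    # transformers.models.bert is only imported by this lookup:
    config_class = CONFIG_MAPPING["bert"]
    print(config_class.__name__)  # "BertConfig"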
class _LazyLoadAllMappings(OrderedDict):
"""
A mapping that will load all pairs of keys and values at the first access (either by indexing, requesting keys, values,
etc.)
Args:
mapping: The mapping to load.
"""
def __init__(self, mapping):
self._mapping = mapping
self._initialized = False
self._data = {}
def _initialize(self):
if self._initialized:
return
warnings.warn(
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
FutureWarning,
)
for model_type, map_name in self._mapping.items():
module_name = model_type_to_module_name(model_type)
module = importlib.import_module(f".{module_name}", "transformers.models")
mapping = getattr(module, map_name)
self._data.update(mapping)
self._initialized = True
def __getitem__(self, key):
self._initialize()
return self._data[key]
def keys(self):
self._initialize()
return self._data.keys()
def values(self):
self._initialize()
return self._data.values()
def items(self):
self._initialize()
return self._data.items()
def __iter__(self):
self._initialize()
return iter(self._data)
def __contains__(self, item):
self._initialize()
return item in self._data
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPPING_NAMES)
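Sketch of the deprecation path (illustrative): the aggregate archive map stays importable, but the per-model imports and the FutureWarning only happen on first access.

    import warnings

    from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # First access triggers _initialize(), which loads every archive map.
        has_bert = "bert-base-uncased" in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP

    assert has_bert
    assert any(issubclass(w.category, FutureWarning) for w in caught)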
def _get_class_name(model_class: Union[str, List[str]]):
if isinstance(model_class, (list, tuple)):
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
return f":class:`~transformers.{model_class.__name__}`"
return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None])
return f":class:`~transformers.{model_class}`"
def _list_model_options(indent, config_to_class=None, use_model_types=True):
@@ -306,23 +353,26 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
if use_model_types:
if config_to_class is None:
model_type_to_name = {
model_type: f":class:`~transformers.{config.__name__}`"
for model_type, config in CONFIG_MAPPING.items()
model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
}
else:
model_type_to_name = {
model_type: _get_class_name(config_to_class[config])
for model_type, config in CONFIG_MAPPING.items()
if config in config_to_class
model_type: _get_class_name(model_class)
for model_type, model_class in config_to_class.items()
if model_type in MODEL_NAMES_MAPPING
}
lines = [
f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type in sorted(model_type_to_name.keys())
]
else:
config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
config_to_name = {
CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
for config, clas in config_to_class.items()
if config in CONFIG_MAPPING_NAMES
}
config_to_model_name = {
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
}
lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"

View File

@@ -13,36 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" AutoFeatureExtractor class. """
import importlib
import os
from collections import OrderedDict
from transformers import BeitFeatureExtractor, DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor
from ... import BeitConfig, DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config
from ...feature_extraction_utils import FeatureExtractionMixin
# Build the list of all feature extractors
from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import FEATURE_EXTRACTOR_NAME
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
FEATURE_EXTRACTOR_MAPPING = OrderedDict(
[
(BeitConfig, BeitFeatureExtractor),
(DeiTConfig, DeiTFeatureExtractor),
(Speech2TextConfig, Speech2TextFeatureExtractor),
(ViTConfig, ViTFeatureExtractor),
(Wav2Vec2Config, Wav2Vec2FeatureExtractor),
]
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
replace_list_option_in_docstrings,
)
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
]
)
FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRACTOR_MAPPING_NAMES)
def feature_extractor_class_from_name(class_name: str):
for c in FEATURE_EXTRACTOR_MAPPING.values():
if c is not None and c.__name__ == class_name:
return c
for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
if class_name in extractors:
break
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
class AutoFeatureExtractor:
@@ -60,7 +67,7 @@ class AutoFeatureExtractor:
)
@classmethod
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
@@ -142,7 +149,8 @@ class AutoFeatureExtractor:
kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if type(config) in FEATURE_EXTRACTOR_MAPPING.keys():
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
elif "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
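The public behavior of AutoFeatureExtractor is meant to be unchanged by this refactor; a quick sketch (the checkpoint name is illustrative and the call needs network access on first use):

    from transformers import AutoFeatureExtractor

    extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
    print(type(extractor).__name__)  # "Wav2Vec2FeatureExtractor"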

File diff suppressed because it is too large

View File

@@ -18,207 +18,163 @@
from collections import OrderedDict
from ...utils import logging
from ..bart.modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
from ..bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..big_bird.modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
FlaxBigBirdModel,
)
from ..clip.modeling_flax_clip import FlaxCLIPModel
from ..electra.modeling_flax_electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
FlaxElectraForQuestionAnswering,
FlaxElectraForSequenceClassification,
FlaxElectraForTokenClassification,
FlaxElectraModel,
)
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel
from ..mbart.modeling_flax_mbart import (
FlaxMBartForConditionalGeneration,
FlaxMBartForQuestionAnswering,
FlaxMBartForSequenceClassification,
FlaxMBartModel,
)
from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
BartConfig,
BertConfig,
BigBirdConfig,
CLIPConfig,
ElectraConfig,
GPT2Config,
GPTNeoConfig,
MarianConfig,
MBartConfig,
MT5Config,
RobertaConfig,
T5Config,
ViTConfig,
Wav2Vec2Config,
)
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING = OrderedDict(
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
(BigBirdConfig, FlaxBigBirdModel),
(BartConfig, FlaxBartModel),
(GPT2Config, FlaxGPT2Model),
(GPTNeoConfig, FlaxGPTNeoModel),
(ElectraConfig, FlaxElectraModel),
(CLIPConfig, FlaxCLIPModel),
(ViTConfig, FlaxViTModel),
(MBartConfig, FlaxMBartModel),
(T5Config, FlaxT5Model),
(MT5Config, FlaxMT5Model),
(Wav2Vec2Config, FlaxWav2Vec2Model),
(MarianConfig, FlaxMarianModel),
("roberta", "FlaxRobertaModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
("bart", "FlaxBartModel"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("electra", "FlaxElectraModel"),
("clip", "FlaxCLIPModel"),
("vit", "FlaxViTModel"),
("mbart", "FlaxMBartModel"),
("t5", "FlaxT5Model"),
("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
]
)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
(BigBirdConfig, FlaxBigBirdForPreTraining),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
(MBartConfig, FlaxMBartForConditionalGeneration),
(T5Config, FlaxT5ForConditionalGeneration),
(MT5Config, FlaxMT5ForConditionalGeneration),
(Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForPreTraining"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
]
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
(BigBirdConfig, FlaxBigBirdForMaskedLM),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForMaskedLM),
(MBartConfig, FlaxMBartForConditionalGeneration),
("roberta", "FlaxRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"),
]
)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(BartConfig, FlaxBartForConditionalGeneration),
(T5Config, FlaxT5ForConditionalGeneration),
(MT5Config, FlaxMT5ForConditionalGeneration),
(MarianConfig, FlaxMarianMTModel),
("bart", "FlaxBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
]
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image Classification mapping
(ViTConfig, FlaxViTForImageClassification),
("vit", "FlaxViTForImageClassification"),
]
)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
(GPT2Config, FlaxGPT2LMHeadModel),
(GPTNeoConfig, FlaxGPTNeoForCausalLM),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
(BigBirdConfig, FlaxBigBirdForSequenceClassification),
(BartConfig, FlaxBartForSequenceClassification),
(ElectraConfig, FlaxElectraForSequenceClassification),
(MBartConfig, FlaxMBartForSequenceClassification),
("roberta", "FlaxRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"),
]
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
(BigBirdConfig, FlaxBigBirdForQuestionAnswering),
(BartConfig, FlaxBartForQuestionAnswering),
(ElectraConfig, FlaxElectraForQuestionAnswering),
(MBartConfig, FlaxMBartForQuestionAnswering),
("roberta", "FlaxRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"),
]
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
(RobertaConfig, FlaxRobertaForTokenClassification),
(BertConfig, FlaxBertForTokenClassification),
(BigBirdConfig, FlaxBigBirdForTokenClassification),
(ElectraConfig, FlaxElectraForTokenClassification),
("roberta", "FlaxRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
]
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
(RobertaConfig, FlaxRobertaForMultipleChoice),
(BertConfig, FlaxBertForMultipleChoice),
(BigBirdConfig, FlaxBigBirdForMultipleChoice),
(ElectraConfig, FlaxElectraForMultipleChoice),
("roberta", "FlaxRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
]
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
(BertConfig, FlaxBertForNextSentencePrediction),
("bert", "FlaxBertForNextSentencePrediction"),
]
)
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
class FlaxAutoModel(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_MAPPING
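As in the other frameworks, the Flax auto classes now resolve through _LazyAutoMapping while their public API stays the same. A sketch (checkpoint name illustrative; from_pt=True may be needed if a repo hosts no Flax weights):

    from transformers import FlaxAutoModel

    # Resolves BertConfig -> FlaxBertModel through FLAX_MODEL_MAPPING at call time.
    model = FlaxAutoModel.from_pretrained("bert-base-cased")
    print(type(model).__name__)  # "FlaxBertModel"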

View File

@@ -19,492 +19,298 @@ import warnings
from collections import OrderedDict
from ...utils import logging
# Add modeling imports here
from ..albert.modeling_tf_albert import (
TFAlbertForMaskedLM,
TFAlbertForMultipleChoice,
TFAlbertForPreTraining,
TFAlbertForQuestionAnswering,
TFAlbertForSequenceClassification,
TFAlbertForTokenClassification,
TFAlbertModel,
)
from ..bart.modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
from ..bert.modeling_tf_bert import (
TFBertForMaskedLM,
TFBertForMultipleChoice,
TFBertForNextSentencePrediction,
TFBertForPreTraining,
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertForTokenClassification,
TFBertLMHeadModel,
TFBertModel,
)
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from ..blenderbot_small.modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
)
from ..camembert.modeling_tf_camembert import (
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
TFCamembertForQuestionAnswering,
TFCamembertForSequenceClassification,
TFCamembertForTokenClassification,
TFCamembertModel,
)
from ..convbert.modeling_tf_convbert import (
TFConvBertForMaskedLM,
TFConvBertForMultipleChoice,
TFConvBertForQuestionAnswering,
TFConvBertForSequenceClassification,
TFConvBertForTokenClassification,
TFConvBertModel,
)
from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel
from ..distilbert.modeling_tf_distilbert import (
TFDistilBertForMaskedLM,
TFDistilBertForMultipleChoice,
TFDistilBertForQuestionAnswering,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertModel,
)
from ..dpr.modeling_tf_dpr import TFDPRQuestionEncoder
from ..electra.modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraModel,
)
from ..flaubert.modeling_tf_flaubert import (
TFFlaubertForMultipleChoice,
TFFlaubertForQuestionAnsweringSimple,
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertWithLMHeadModel,
)
from ..funnel.modeling_tf_funnel import (
TFFunnelBaseModel,
TFFunnelForMaskedLM,
TFFunnelForMultipleChoice,
TFFunnelForPreTraining,
TFFunnelForQuestionAnswering,
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..hubert.modeling_tf_hubert import TFHubertModel
from ..layoutlm.modeling_tf_layoutlm import (
TFLayoutLMForMaskedLM,
TFLayoutLMForSequenceClassification,
TFLayoutLMForTokenClassification,
TFLayoutLMModel,
)
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
TFMobileBertForNextSentencePrediction,
TFMobileBertForPreTraining,
TFMobileBertForQuestionAnswering,
TFMobileBertForSequenceClassification,
TFMobileBertForTokenClassification,
TFMobileBertModel,
)
from ..mpnet.modeling_tf_mpnet import (
TFMPNetForMaskedLM,
TFMPNetForMultipleChoice,
TFMPNetForQuestionAnswering,
TFMPNetForSequenceClassification,
TFMPNetForTokenClassification,
TFMPNetModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..rembert.modeling_tf_rembert import (
TFRemBertForCausalLM,
TFRemBertForMaskedLM,
TFRemBertForMultipleChoice,
TFRemBertForQuestionAnswering,
TFRemBertForSequenceClassification,
TFRemBertForTokenClassification,
TFRemBertModel,
)
from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaModel,
)
from ..roformer.modeling_tf_roformer import (
TFRoFormerForCausalLM,
TFRoFormerForMaskedLM,
TFRoFormerForMultipleChoice,
TFRoFormerForQuestionAnswering,
TFRoFormerForSequenceClassification,
TFRoFormerForTokenClassification,
TFRoFormerModel,
)
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLModel,
)
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
TFXLMForSequenceClassification,
TFXLMForTokenClassification,
TFXLMModel,
TFXLMWithLMHeadModel,
)
from ..xlm_roberta.modeling_tf_xlm_roberta import (
TFXLMRobertaForMaskedLM,
TFXLMRobertaForMultipleChoice,
TFXLMRobertaForQuestionAnswering,
TFXLMRobertaForSequenceClassification,
TFXLMRobertaForTokenClassification,
TFXLMRobertaModel,
)
from ..xlnet.modeling_tf_xlnet import (
TFXLNetForMultipleChoice,
TFXLNetForQuestionAnsweringSimple,
TFXLNetForSequenceClassification,
TFXLNetForTokenClassification,
TFXLNetLMHeadModel,
TFXLNetModel,
)
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
AlbertConfig,
BartConfig,
BertConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
ConvBertConfig,
CTRLConfig,
DistilBertConfig,
DPRConfig,
ElectraConfig,
FlaubertConfig,
FunnelConfig,
GPT2Config,
HubertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LxmertConfig,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
RemBertConfig,
RobertaConfig,
RoFormerConfig,
T5Config,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
)
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
(RemBertConfig, TFRemBertModel),
(RoFormerConfig, TFRoFormerModel),
(ConvBertConfig, TFConvBertModel),
(LEDConfig, TFLEDModel),
(LxmertConfig, TFLxmertModel),
(MT5Config, TFMT5Model),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
(BartConfig, TFBartModel),
(CamembertConfig, TFCamembertModel),
(XLMRobertaConfig, TFXLMRobertaModel),
(LongformerConfig, TFLongformerModel),
(RobertaConfig, TFRobertaModel),
(LayoutLMConfig, TFLayoutLMModel),
(BertConfig, TFBertModel),
(OpenAIGPTConfig, TFOpenAIGPTModel),
(GPT2Config, TFGPT2Model),
(MobileBertConfig, TFMobileBertModel),
(TransfoXLConfig, TFTransfoXLModel),
(XLNetConfig, TFXLNetModel),
(FlaubertConfig, TFFlaubertModel),
(XLMConfig, TFXLMModel),
(CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel),
(FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
(MBartConfig, TFMBartModel),
(MarianConfig, TFMarianModel),
(PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
(Wav2Vec2Config, TFWav2Vec2Model),
(HubertConfig, TFHubertModel),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
("t5", "TFT5Model"),
("distilbert", "TFDistilBertModel"),
("albert", "TFAlbertModel"),
("bart", "TFBartModel"),
("camembert", "TFCamembertModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("longformer", "TFLongformerModel"),
("roberta", "TFRobertaModel"),
("layoutlm", "TFLayoutLMModel"),
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
("flaubert", "TFFlaubertModel"),
("xlm", "TFXLMModel"),
("ctrl", "TFCTRLModel"),
("electra", "TFElectraModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("dpr", "TFDPRQuestionEncoder"),
("mpnet", "TFMPNetModel"),
("mbart", "TFMBartModel"),
("marian", "TFMarianModel"),
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"),
]
)
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
(LxmertConfig, TFLxmertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForPreTraining),
(BartConfig, TFBartForConditionalGeneration),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForPreTraining),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(MobileBertConfig, TFMobileBertForPreTraining),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(CTRLConfig, TFCTRLLMHeadModel),
(ElectraConfig, TFElectraForPreTraining),
(FunnelConfig, TFFunnelForPreTraining),
(MPNetConfig, TFMPNetForMaskedLM),
("lxmert", "TFLxmertForPreTraining"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForPreTraining"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForPreTraining"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForPreTraining"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForPreTraining"),
("funnel", "TFFunnelForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
(RemBertConfig, TFRemBertForMaskedLM),
(RoFormerConfig, TFRoFormerForMaskedLM),
(ConvBertConfig, TFConvBertForMaskedLM),
(LEDConfig, TFLEDForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(MarianConfig, TFMarianMTModel),
(BartConfig, TFBartForConditionalGeneration),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(MobileBertConfig, TFMobileBertForMaskedLM),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(CTRLConfig, TFCTRLLMHeadModel),
(ElectraConfig, TFElectraForMaskedLM),
(FunnelConfig, TFFunnelForMaskedLM),
(MPNetConfig, TFMPNetForMaskedLM),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("marian", "TFMarianMTModel"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
(RemBertConfig, TFRemBertForCausalLM),
(RoFormerConfig, TFRoFormerForCausalLM),
(BertConfig, TFBertLMHeadModel),
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
(GPT2Config, TFGPT2LMHeadModel),
(TransfoXLConfig, TFTransfoXLLMHeadModel),
(XLNetConfig, TFXLNetLMHeadModel),
(
XLMConfig,
TFXLMWithLMHeadModel,
), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
(CTRLConfig, TFCTRLLMHeadModel),
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
(RemBertConfig, TFRemBertForMaskedLM),
(RoFormerConfig, TFRoFormerForMaskedLM),
(ConvBertConfig, TFConvBertForMaskedLM),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForMaskedLM),
(CamembertConfig, TFCamembertForMaskedLM),
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
(LongformerConfig, TFLongformerForMaskedLM),
(RobertaConfig, TFRobertaForMaskedLM),
(LayoutLMConfig, TFLayoutLMForMaskedLM),
(BertConfig, TFBertForMaskedLM),
(MobileBertConfig, TFMobileBertForMaskedLM),
(FlaubertConfig, TFFlaubertWithLMHeadModel),
(XLMConfig, TFXLMWithLMHeadModel),
(ElectraConfig, TFElectraForMaskedLM),
(FunnelConfig, TFFunnelForMaskedLM),
(MPNetConfig, TFMPNetForMaskedLM),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(LEDConfig, TFLEDForConditionalGeneration),
(MT5Config, TFMT5ForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),
(MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration),
("led", "TFLEDForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
(RemBertConfig, TFRemBertForSequenceClassification),
(RoFormerConfig, TFRoFormerForSequenceClassification),
(ConvBertConfig, TFConvBertForSequenceClassification),
(DistilBertConfig, TFDistilBertForSequenceClassification),
(AlbertConfig, TFAlbertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(LongformerConfig, TFLongformerForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(LayoutLMConfig, TFLayoutLMForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
(MobileBertConfig, TFMobileBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
(ElectraConfig, TFElectraForSequenceClassification),
(FunnelConfig, TFFunnelForSequenceClassification),
(GPT2Config, TFGPT2ForSequenceClassification),
(MPNetConfig, TFMPNetForSequenceClassification),
(OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
(TransfoXLConfig, TFTransfoXLForSequenceClassification),
(CTRLConfig, TFCTRLForSequenceClassification),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("albert", "TFAlbertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
(RemBertConfig, TFRemBertForQuestionAnswering),
(RoFormerConfig, TFRoFormerForQuestionAnswering),
(ConvBertConfig, TFConvBertForQuestionAnswering),
(DistilBertConfig, TFDistilBertForQuestionAnswering),
(AlbertConfig, TFAlbertForQuestionAnswering),
(CamembertConfig, TFCamembertForQuestionAnswering),
(XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
(LongformerConfig, TFLongformerForQuestionAnswering),
(RobertaConfig, TFRobertaForQuestionAnswering),
(BertConfig, TFBertForQuestionAnswering),
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
(MobileBertConfig, TFMobileBertForQuestionAnswering),
(FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
(XLMConfig, TFXLMForQuestionAnsweringSimple),
(ElectraConfig, TFElectraForQuestionAnswering),
(FunnelConfig, TFFunnelForQuestionAnswering),
(MPNetConfig, TFMPNetForQuestionAnswering),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("albert", "TFAlbertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
]
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
(RemBertConfig, TFRemBertForTokenClassification),
(RoFormerConfig, TFRoFormerForTokenClassification),
(ConvBertConfig, TFConvBertForTokenClassification),
(DistilBertConfig, TFDistilBertForTokenClassification),
(AlbertConfig, TFAlbertForTokenClassification),
(CamembertConfig, TFCamembertForTokenClassification),
(FlaubertConfig, TFFlaubertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(LongformerConfig, TFLongformerForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(LayoutLMConfig, TFLayoutLMForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
(XLNetConfig, TFXLNetForTokenClassification),
(ElectraConfig, TFElectraForTokenClassification),
(FunnelConfig, TFFunnelForTokenClassification),
(MPNetConfig, TFMPNetForTokenClassification),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("albert", "TFAlbertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
(RemBertConfig, TFRemBertForMultipleChoice),
(RoFormerConfig, TFRoFormerForMultipleChoice),
(ConvBertConfig, TFConvBertForMultipleChoice),
(CamembertConfig, TFCamembertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(LongformerConfig, TFLongformerForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
(MobileBertConfig, TFMobileBertForMultipleChoice),
(XLNetConfig, TFXLNetForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(AlbertConfig, TFAlbertForMultipleChoice),
(ElectraConfig, TFElectraForMultipleChoice),
(FunnelConfig, TFFunnelForMultipleChoice),
(MPNetConfig, TFMPNetForMultipleChoice),
("rembert", "TFRemBertForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
]
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
(BertConfig, TFBertForNextSentencePrediction),
(MobileBertConfig, TFMobileBertForNextSentencePrediction),
("bert", "TFBertForNextSentencePrediction"),
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
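For context, a minimal sketch (not the library's actual code) of the lazy-resolution idea behind these `_LazyAutoMapping` instances: the mappings above hold only model-type and class-name strings, and the real class is imported on first lookup. The names below are illustrative.

import importlib
from collections import OrderedDict

# Model type -> class name, mirroring the shape of TF_MODEL_MAPPING_NAMES.
MODEL_NAMES = OrderedDict([("bert", "TFBertForMaskedLM")])

class LazyMappingSketch:
    def __init__(self, model_names):
        self._model_names = model_names
        self._cache = {}  # resolved classes, filled on demand

    def resolve(self, model_type):
        # Import transformers.models.<model_type> only when first asked for it,
        # so that `import transformers` stays cheap.
        if model_type not in self._cache:
            module = importlib.import_module(f".{model_type}", "transformers.models")
            self._cache[model_type] = getattr(module, self._model_names[model_type])
        return self._cache[model_type]

The design choice is the same as for the tokenizer mappings further down: holding strings instead of classes removes the hard import dependency between the auto modules and every modeling file.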

View File

@ -14,12 +14,12 @@
# limitations under the License.
""" Auto Tokenizer class. """
import importlib
import json
import os
from collections import OrderedDict
from typing import Dict, Optional, Union
from ... import GPTNeoConfig
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
cached_path,
@ -29,315 +29,183 @@ from ...file_utils import (
is_tokenizers_available,
)
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from ..bart.tokenization_bart import BartTokenizer
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..byt5.tokenization_byt5 import ByT5Tokenizer
from ..canine.tokenization_canine import CanineTokenizer
from ..convbert.tokenization_convbert import ConvBertTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
from ..distilbert.tokenization_distilbert import DistilBertTokenizer
from ..dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from ..electra.tokenization_electra import ElectraTokenizer
from ..flaubert.tokenization_flaubert import FlaubertTokenizer
from ..fsmt.tokenization_fsmt import FSMTTokenizer
from ..funnel.tokenization_funnel import FunnelTokenizer
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
from ..herbert.tokenization_herbert import HerbertTokenizer
from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer
from ..led.tokenization_led import LEDTokenizer
from ..longformer.tokenization_longformer import LongformerTokenizer
from ..luke.tokenization_luke import LukeTokenizer
from ..lxmert.tokenization_lxmert import LxmertTokenizer
from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer
from ..mpnet.tokenization_mpnet import MPNetTokenizer
from ..openai.tokenization_openai import OpenAIGPTTokenizer
from ..phobert.tokenization_phobert import PhobertTokenizer
from ..prophetnet.tokenization_prophetnet import ProphetNetTokenizer
from ..rag.tokenization_rag import RagTokenizer
from ..retribert.tokenization_retribert import RetriBertTokenizer
from ..roberta.tokenization_roberta import RobertaTokenizer
from ..roformer.tokenization_roformer import RoFormerTokenizer
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
from ..tapas.tokenization_tapas import TapasTokenizer
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
from ..xlm.tokenization_xlm import XLMTokenizer
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
AlbertConfig,
CONFIG_MAPPING_NAMES,
AutoConfig,
BartConfig,
BertConfig,
BertGenerationConfig,
BigBirdConfig,
BigBirdPegasusConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CanineConfig,
ConvBertConfig,
CTRLConfig,
DebertaConfig,
DebertaV2Config,
DistilBertConfig,
DPRConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
FSMTConfig,
FunnelConfig,
GPT2Config,
HubertConfig,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
LukeConfig,
LxmertConfig,
M2M100Config,
MarianConfig,
MBartConfig,
MobileBertConfig,
MPNetConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
RagConfig,
ReformerConfig,
RetriBertConfig,
RobertaConfig,
RoFormerConfig,
Speech2TextConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
XLNetConfig,
config_class_to_model_type,
replace_list_option_in_docstrings,
)
if is_sentencepiece_available():
from ..albert.tokenization_albert import AlbertTokenizer
from ..barthez.tokenization_barthez import BarthezTokenizer
from ..bert_generation.tokenization_bert_generation import BertGenerationTokenizer
from ..big_bird.tokenization_big_bird import BigBirdTokenizer
from ..camembert.tokenization_camembert import CamembertTokenizer
from ..cpm.tokenization_cpm import CpmTokenizer
from ..deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer
from ..m2m_100 import M2M100Tokenizer
from ..marian.tokenization_marian import MarianTokenizer
from ..mbart.tokenization_mbart import MBartTokenizer
from ..mbart.tokenization_mbart50 import MBart50Tokenizer
from ..mt5 import MT5Tokenizer
from ..pegasus.tokenization_pegasus import PegasusTokenizer
from ..reformer.tokenization_reformer import ReformerTokenizer
from ..speech_to_text import Speech2TextTokenizer
from ..t5.tokenization_t5 import T5Tokenizer
from ..xlm_prophetnet.tokenization_xlm_prophetnet import XLMProphetNetTokenizer
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from ..xlnet.tokenization_xlnet import XLNetTokenizer
else:
AlbertTokenizer = None
BarthezTokenizer = None
BertGenerationTokenizer = None
BigBirdTokenizer = None
CamembertTokenizer = None
CpmTokenizer = None
DebertaV2Tokenizer = None
MarianTokenizer = None
MBartTokenizer = None
MBart50Tokenizer = None
MT5Tokenizer = None
PegasusTokenizer = None
ReformerTokenizer = None
T5Tokenizer = None
XLMRobertaTokenizer = None
XLNetTokenizer = None
XLMProphetNetTokenizer = None
M2M100Tokenizer = None
Speech2TextTokenizer = None
if is_tokenizers_available():
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ..albert.tokenization_albert_fast import AlbertTokenizerFast
from ..bart.tokenization_bart_fast import BartTokenizerFast
from ..barthez.tokenization_barthez_fast import BarthezTokenizerFast
from ..bert.tokenization_bert_fast import BertTokenizerFast
from ..big_bird.tokenization_big_bird_fast import BigBirdTokenizerFast
from ..camembert.tokenization_camembert_fast import CamembertTokenizerFast
from ..convbert.tokenization_convbert_fast import ConvBertTokenizerFast
from ..cpm.tokenization_cpm_fast import CpmTokenizerFast
from ..deberta.tokenization_deberta_fast import DebertaTokenizerFast
from ..distilbert.tokenization_distilbert_fast import DistilBertTokenizerFast
from ..dpr.tokenization_dpr_fast import DPRQuestionEncoderTokenizerFast
from ..electra.tokenization_electra_fast import ElectraTokenizerFast
from ..funnel.tokenization_funnel_fast import FunnelTokenizerFast
from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from ..herbert.tokenization_herbert_fast import HerbertTokenizerFast
from ..layoutlm.tokenization_layoutlm_fast import LayoutLMTokenizerFast
from ..led.tokenization_led_fast import LEDTokenizerFast
from ..longformer.tokenization_longformer_fast import LongformerTokenizerFast
from ..lxmert.tokenization_lxmert_fast import LxmertTokenizerFast
from ..mbart.tokenization_mbart50_fast import MBart50TokenizerFast
from ..mbart.tokenization_mbart_fast import MBartTokenizerFast
from ..mobilebert.tokenization_mobilebert_fast import MobileBertTokenizerFast
from ..mpnet.tokenization_mpnet_fast import MPNetTokenizerFast
from ..mt5 import MT5TokenizerFast
from ..openai.tokenization_openai_fast import OpenAIGPTTokenizerFast
from ..pegasus.tokenization_pegasus_fast import PegasusTokenizerFast
from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
from ..t5.tokenization_t5_fast import T5TokenizerFast
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
from ..xlnet.tokenization_xlnet_fast import XLNetTokenizerFast
else:
AlbertTokenizerFast = None
BartTokenizerFast = None
BarthezTokenizerFast = None
BertTokenizerFast = None
BigBirdTokenizerFast = None
CamembertTokenizerFast = None
ConvBertTokenizerFast = None
CpmTokenizerFast = None
DebertaTokenizerFast = None
DistilBertTokenizerFast = None
DPRQuestionEncoderTokenizerFast = None
ElectraTokenizerFast = None
FunnelTokenizerFast = None
GPT2TokenizerFast = None
HerbertTokenizerFast = None
LayoutLMTokenizerFast = None
LEDTokenizerFast = None
LongformerTokenizerFast = None
LxmertTokenizerFast = None
MBartTokenizerFast = None
MBart50TokenizerFast = None
MobileBertTokenizerFast = None
MPNetTokenizerFast = None
MT5TokenizerFast = None
OpenAIGPTTokenizerFast = None
PegasusTokenizerFast = None
ReformerTokenizerFast = None
RetriBertTokenizerFast = None
RobertaTokenizerFast = None
RoFormerTokenizerFast = None
SqueezeBertTokenizerFast = None
T5TokenizerFast = None
XLMRobertaTokenizerFast = None
XLNetTokenizerFast = None
PreTrainedTokenizerFast = None
logger = logging.get_logger(__name__)
TOKENIZER_MAPPING = OrderedDict(
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
(RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
(T5Config, (T5Tokenizer, T5TokenizerFast)),
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, AlbertTokenizerFast)),
(CamembertConfig, (CamembertTokenizer, CamembertTokenizerFast)),
(PegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)),
(BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotConfig, (BlenderbotTokenizer, None)),
(BartConfig, (BartTokenizer, BartTokenizerFast)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(ReformerConfig, (ReformerTokenizer, ReformerTokenizerFast)),
(ElectraConfig, (ElectraTokenizer, ElectraTokenizerFast)),
(FunnelConfig, (FunnelTokenizer, FunnelTokenizerFast)),
(LxmertConfig, (LxmertTokenizer, LxmertTokenizerFast)),
(LayoutLMConfig, (LayoutLMTokenizer, LayoutLMTokenizerFast)),
(DPRConfig, (DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast)),
(SqueezeBertConfig, (SqueezeBertTokenizer, SqueezeBertTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
(TransfoXLConfig, (TransfoXLTokenizer, None)),
(XLNetConfig, (XLNetTokenizer, XLNetTokenizerFast)),
(FlaubertConfig, (FlaubertTokenizer, None)),
(XLMConfig, (XLMTokenizer, None)),
(CTRLConfig, (CTRLTokenizer, None)),
(FSMTConfig, (FSMTTokenizer, None)),
(BertGenerationConfig, (BertGenerationTokenizer, None)),
(DebertaConfig, (DebertaTokenizer, DebertaTokenizerFast)),
(DebertaV2Config, (DebertaV2Tokenizer, None)),
(RagConfig, (RagTokenizer, None)),
(XLMProphetNetConfig, (XLMProphetNetTokenizer, None)),
(Speech2TextConfig, (Speech2TextTokenizer, None)),
(M2M100Config, (M2M100Tokenizer, None)),
(ProphetNetConfig, (ProphetNetTokenizer, None)),
(MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
(ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
(BigBirdConfig, (BigBirdTokenizer, BigBirdTokenizerFast)),
(IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
(HubertConfig, (Wav2Vec2CTCTokenizer, None)),
(GPTNeoConfig, (GPT2Tokenizer, GPT2TokenizerFast)),
(LukeConfig, (LukeTokenizer, None)),
(BigBirdPegasusConfig, (PegasusTokenizer, PegasusTokenizerFast)),
(CanineConfig, (CanineTokenizer, None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mt5",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"albert",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"camembert",
(
"CamembertTokenizer" if is_sentencepiece_available() else None,
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"pegasus",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"xlm-roberta",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
("DPRQuestionEncoderTokenizer", "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None),
),
("squeezebert", ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None)),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta-v2", ("DebertaV2Tokenizer" if is_sentencepiece_available() else None, None)),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("byt5", ("ByT5Tokenizer", None)),
(
"cpm",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
(
"barthez",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"mbart50",
(
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
]
)
# For tokenizers which are not directly mapped from a config
NO_CONFIG_TOKENIZER = [
BertJapaneseTokenizer,
BertweetTokenizer,
ByT5Tokenizer,
CpmTokenizer,
CpmTokenizerFast,
HerbertTokenizer,
HerbertTokenizerFast,
PhobertTokenizer,
BarthezTokenizer,
BarthezTokenizerFast,
MBart50Tokenizer,
MBart50TokenizerFast,
PreTrainedTokenizerFast,
]
TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
SLOW_TOKENIZER_MAPPING = {
k: (v[0] if v[0] is not None else v[1])
for k, v in TOKENIZER_MAPPING.items()
if (v[0] is not None or v[1] is not None)
}
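Spelled out, the comprehension above prefers the slow tokenizer name and falls back to the fast one when no slow class exists; a small self-contained illustration (the fast-only entry is hypothetical):

pairs = {
    "bart": ("BartTokenizer", "BartTokenizerFast"),
    "fast-only": (None, "SomeFastTokenizer"),  # hypothetical fast-only entry
}
slow_pref = {
    k: (v[0] if v[0] is not None else v[1])
    for k, v in pairs.items()
    if (v[0] is not None or v[1] is not None)
}
assert slow_pref == {"bart": "BartTokenizer", "fast-only": "SomeFastTokenizer"}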
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
def tokenizer_class_from_name(class_name: str):
all_tokenizer_classes = (
[v[0] for v in TOKENIZER_MAPPING.values() if v[0] is not None]
+ [v[1] for v in TOKENIZER_MAPPING.values() if v[1] is not None]
+ [v for v in NO_CONFIG_TOKENIZER if v is not None]
)
for c in all_tokenizer_classes:
if c.__name__ == class_name:
return c
if class_name == "PreTrainedTokenizerFast":
return PreTrainedTokenizerFast
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if class_name in tokenizers:
break
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
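A quick usage sketch of the helper above, assuming transformers with the tokenizers backend installed: the class is looked up by name in the string tables, and only its defining module is imported.

from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

cls = tokenizer_class_from_name("BertTokenizerFast")  # imports transformers.models.bert lazily
print(cls.__name__)  # BertTokenizerFast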
def get_tokenizer_config(
@ -454,7 +322,7 @@ class AutoTokenizer:
)
@classmethod
@replace_list_option_in_docstrings(SLOW_TOKENIZER_MAPPING)
@replace_list_option_in_docstrings(TOKENIZER_MAPPING_NAMES)
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
r"""
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
@ -565,7 +433,8 @@ class AutoTokenizer:
)
config = config.encoder
if type(config) in TOKENIZER_MAPPING.keys():
model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None:
tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
if tokenizer_class_fast and (use_fast or tokenizer_class_py is None):
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
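For reference, the user-facing behavior this branch implements (the checkpoint name is an example, not part of the diff): the fast tokenizer is returned when available, unless `use_fast=False` is passed.

from transformers import AutoTokenizer

fast = AutoTokenizer.from_pretrained("bert-base-uncased")                  # BertTokenizerFast when tokenizers is installed
slow = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)  # forces the slow BertTokenizer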

View File

@ -33,10 +33,8 @@ _import_structure = {
if is_sentencepiece_available():
_import_structure["tokenization_mbart"] = ["MBartTokenizer"]
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
if is_torch_available():
@ -72,10 +70,8 @@ if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart import MBartTokenizer
from .tokenization_mbart50 import MBart50Tokenizer
if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
from .tokenization_mbart_fast import MBartTokenizerFast
if is_torch_available():

View File

@ -0,0 +1,42 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
_import_structure = {}
if is_sentencepiece_available():
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mbart50 import MBart50Tokenizer
if is_tokenizers_available():
from .tokenization_mbart50_fast import MBart50TokenizerFast
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
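With this `_LazyModule` wiring, `import transformers.models.mbart50` is essentially free; the tokenizer files (and their sentencepiece/tokenizers dependencies) load only on attribute access. A one-line illustration, assuming sentencepiece is installed:

from transformers.models.mbart50 import MBart50Tokenizer  # tokenization_mbart50 is imported only now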

View File

@ -1781,25 +1781,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if config_tokenizer_class is None:
# Third attempt. If we have not yet found the original type of the tokenizer we are loading,
# we see if we can infer it from the type of the configuration file.
from .models.auto.configuration_auto import CONFIG_MAPPING # tests_ignore
from .models.auto.tokenization_auto import TOKENIZER_MAPPING # tests_ignore
from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore
if hasattr(config, "model_type"):
config_class = CONFIG_MAPPING.get(config.model_type)
model_type = config.model_type
else:
# Fallback: use pattern matching on the string.
config_class = None
for pattern, config_class_tmp in CONFIG_MAPPING.items():
model_type = None
for pattern in TOKENIZER_MAPPING_NAMES.keys():
if pattern in str(pretrained_model_name_or_path):
config_class = config_class_tmp
model_type = pattern
break
if config_class in TOKENIZER_MAPPING.keys():
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING[config_class]
if config_tokenizer_class is not None:
config_tokenizer_class = config_tokenizer_class.__name__
else:
config_tokenizer_class = config_tokenizer_class_fast.__name__
if model_type is not None:
config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES[model_type]
if config_tokenizer_class is None:
config_tokenizer_class = config_tokenizer_class_fast
if config_tokenizer_class is not None:
if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
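To illustrate the string-based fallback above: when the config carries no `model_type`, the checkpoint path itself is scanned for a known model-type key, relying on the insertion order of `TOKENIZER_MAPPING_NAMES` (e.g. "roberta" is checked before "bert"). A self-contained sketch with a hypothetical checkpoint name:

path = "my-org/roberta-finetuned-ner"      # hypothetical checkpoint name
known_types = ["roberta", "bert", "gpt2"]  # stand-in for TOKENIZER_MAPPING_NAMES.keys()
model_type = next((t for t in known_types if t in path), None)
assert model_type == "roberta"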

View File

@ -74,6 +74,7 @@ from .file_utils import (
)
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
@ -125,7 +126,6 @@ from .trainer_utils import (
)
from .training_args import ParallelMode, TrainingArguments
from .utils import logging
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
_is_torch_generator_available = False

View File

@ -191,7 +191,7 @@ class LxmertTokenizerFast:
requires_backends(cls, ["tokenizers"])
class MBart50TokenizerFast:
class MBartTokenizerFast:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])
@ -200,7 +200,7 @@ class MBart50TokenizerFast:
requires_backends(cls, ["tokenizers"])
class MBartTokenizerFast:
class MBart50TokenizerFast:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])

View File

@ -1,374 +0,0 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForQuestionAnswering"),
("CanineConfig", "CanineForQuestionAnswering"),
("RoFormerConfig", "RoFormerForQuestionAnswering"),
("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
("BigBirdConfig", "BigBirdForQuestionAnswering"),
("ConvBertConfig", "ConvBertForQuestionAnswering"),
("LEDConfig", "LEDForQuestionAnswering"),
("DistilBertConfig", "DistilBertForQuestionAnswering"),
("AlbertConfig", "AlbertForQuestionAnswering"),
("CamembertConfig", "CamembertForQuestionAnswering"),
("BartConfig", "BartForQuestionAnswering"),
("MBartConfig", "MBartForQuestionAnswering"),
("LongformerConfig", "LongformerForQuestionAnswering"),
("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
("RobertaConfig", "RobertaForQuestionAnswering"),
("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
("BertConfig", "BertForQuestionAnswering"),
("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
("MegatronBertConfig", "MegatronBertForQuestionAnswering"),
("MobileBertConfig", "MobileBertForQuestionAnswering"),
("XLMConfig", "XLMForQuestionAnsweringSimple"),
("ElectraConfig", "ElectraForQuestionAnswering"),
("ReformerConfig", "ReformerForQuestionAnswering"),
("FunnelConfig", "FunnelForQuestionAnswering"),
("LxmertConfig", "LxmertForQuestionAnswering"),
("MPNetConfig", "MPNetForQuestionAnswering"),
("DebertaConfig", "DebertaForQuestionAnswering"),
("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
("IBertConfig", "IBertForQuestionAnswering"),
]
)
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForCausalLM"),
("RoFormerConfig", "RoFormerForCausalLM"),
("BigBirdPegasusConfig", "BigBirdPegasusForCausalLM"),
("GPTNeoConfig", "GPTNeoForCausalLM"),
("BigBirdConfig", "BigBirdForCausalLM"),
("CamembertConfig", "CamembertForCausalLM"),
("XLMRobertaConfig", "XLMRobertaForCausalLM"),
("RobertaConfig", "RobertaForCausalLM"),
("BertConfig", "BertLMHeadModel"),
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
("GPT2Config", "GPT2LMHeadModel"),
("TransfoXLConfig", "TransfoXLLMHeadModel"),
("XLNetConfig", "XLNetLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("CTRLConfig", "CTRLLMHeadModel"),
("ReformerConfig", "ReformerModelWithLMHead"),
("BertGenerationConfig", "BertGenerationDecoder"),
("XLMProphetNetConfig", "XLMProphetNetForCausalLM"),
("ProphetNetConfig", "ProphetNetForCausalLM"),
("BartConfig", "BartForCausalLM"),
("MBartConfig", "MBartForCausalLM"),
("PegasusConfig", "PegasusForCausalLM"),
("MarianConfig", "MarianForCausalLM"),
("BlenderbotConfig", "BlenderbotForCausalLM"),
("BlenderbotSmallConfig", "BlenderbotSmallForCausalLM"),
("MegatronBertConfig", "MegatronBertForCausalLM"),
]
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("ViTConfig", "ViTForImageClassification"),
("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"),
("BeitConfig", "BeitForImageClassification"),
]
)
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMaskedLM"),
("RoFormerConfig", "RoFormerForMaskedLM"),
("BigBirdConfig", "BigBirdForMaskedLM"),
("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
("ConvBertConfig", "ConvBertForMaskedLM"),
("LayoutLMConfig", "LayoutLMForMaskedLM"),
("DistilBertConfig", "DistilBertForMaskedLM"),
("AlbertConfig", "AlbertForMaskedLM"),
("BartConfig", "BartForConditionalGeneration"),
("MBartConfig", "MBartForConditionalGeneration"),
("CamembertConfig", "CamembertForMaskedLM"),
("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
("LongformerConfig", "LongformerForMaskedLM"),
("RobertaConfig", "RobertaForMaskedLM"),
("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
("BertConfig", "BertForMaskedLM"),
("MegatronBertConfig", "MegatronBertForMaskedLM"),
("MobileBertConfig", "MobileBertForMaskedLM"),
("FlaubertConfig", "FlaubertWithLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("ElectraConfig", "ElectraForMaskedLM"),
("ReformerConfig", "ReformerForMaskedLM"),
("FunnelConfig", "FunnelForMaskedLM"),
("MPNetConfig", "MPNetForMaskedLM"),
("TapasConfig", "TapasForMaskedLM"),
("DebertaConfig", "DebertaForMaskedLM"),
("DebertaV2Config", "DebertaV2ForMaskedLM"),
("IBertConfig", "IBertForMaskedLM"),
]
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMultipleChoice"),
("CanineConfig", "CanineForMultipleChoice"),
("RoFormerConfig", "RoFormerForMultipleChoice"),
("BigBirdConfig", "BigBirdForMultipleChoice"),
("ConvBertConfig", "ConvBertForMultipleChoice"),
("CamembertConfig", "CamembertForMultipleChoice"),
("ElectraConfig", "ElectraForMultipleChoice"),
("XLMRobertaConfig", "XLMRobertaForMultipleChoice"),
("LongformerConfig", "LongformerForMultipleChoice"),
("RobertaConfig", "RobertaForMultipleChoice"),
("SqueezeBertConfig", "SqueezeBertForMultipleChoice"),
("BertConfig", "BertForMultipleChoice"),
("DistilBertConfig", "DistilBertForMultipleChoice"),
("MegatronBertConfig", "MegatronBertForMultipleChoice"),
("MobileBertConfig", "MobileBertForMultipleChoice"),
("XLNetConfig", "XLNetForMultipleChoice"),
("AlbertConfig", "AlbertForMultipleChoice"),
("XLMConfig", "XLMForMultipleChoice"),
("FlaubertConfig", "FlaubertForMultipleChoice"),
("FunnelConfig", "FunnelForMultipleChoice"),
("MPNetConfig", "MPNetForMultipleChoice"),
("IBertConfig", "IBertForMultipleChoice"),
]
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("BertConfig", "BertForNextSentencePrediction"),
("MegatronBertConfig", "MegatronBertForNextSentencePrediction"),
("MobileBertConfig", "MobileBertForNextSentencePrediction"),
]
)
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("DetrConfig", "DetrForObjectDetection"),
]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
("M2M100Config", "M2M100ForConditionalGeneration"),
("LEDConfig", "LEDForConditionalGeneration"),
("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
("MT5Config", "MT5ForConditionalGeneration"),
("T5Config", "T5ForConditionalGeneration"),
("PegasusConfig", "PegasusForConditionalGeneration"),
("MarianConfig", "MarianMTModel"),
("MBartConfig", "MBartForConditionalGeneration"),
("BlenderbotConfig", "BlenderbotForConditionalGeneration"),
("BartConfig", "BartForConditionalGeneration"),
("FSMTConfig", "FSMTForConditionalGeneration"),
("EncoderDecoderConfig", "EncoderDecoderModel"),
("XLMProphetNetConfig", "XLMProphetNetForConditionalGeneration"),
("ProphetNetConfig", "ProphetNetForConditionalGeneration"),
]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForSequenceClassification"),
("CanineConfig", "CanineForSequenceClassification"),
("RoFormerConfig", "RoFormerForSequenceClassification"),
("BigBirdPegasusConfig", "BigBirdPegasusForSequenceClassification"),
("BigBirdConfig", "BigBirdForSequenceClassification"),
("ConvBertConfig", "ConvBertForSequenceClassification"),
("LEDConfig", "LEDForSequenceClassification"),
("DistilBertConfig", "DistilBertForSequenceClassification"),
("AlbertConfig", "AlbertForSequenceClassification"),
("CamembertConfig", "CamembertForSequenceClassification"),
("XLMRobertaConfig", "XLMRobertaForSequenceClassification"),
("MBartConfig", "MBartForSequenceClassification"),
("BartConfig", "BartForSequenceClassification"),
("LongformerConfig", "LongformerForSequenceClassification"),
("RobertaConfig", "RobertaForSequenceClassification"),
("SqueezeBertConfig", "SqueezeBertForSequenceClassification"),
("LayoutLMConfig", "LayoutLMForSequenceClassification"),
("BertConfig", "BertForSequenceClassification"),
("XLNetConfig", "XLNetForSequenceClassification"),
("MegatronBertConfig", "MegatronBertForSequenceClassification"),
("MobileBertConfig", "MobileBertForSequenceClassification"),
("FlaubertConfig", "FlaubertForSequenceClassification"),
("XLMConfig", "XLMForSequenceClassification"),
("ElectraConfig", "ElectraForSequenceClassification"),
("FunnelConfig", "FunnelForSequenceClassification"),
("DebertaConfig", "DebertaForSequenceClassification"),
("DebertaV2Config", "DebertaV2ForSequenceClassification"),
("GPT2Config", "GPT2ForSequenceClassification"),
("GPTNeoConfig", "GPTNeoForSequenceClassification"),
("OpenAIGPTConfig", "OpenAIGPTForSequenceClassification"),
("ReformerConfig", "ReformerForSequenceClassification"),
("CTRLConfig", "CTRLForSequenceClassification"),
("TransfoXLConfig", "TransfoXLForSequenceClassification"),
("MPNetConfig", "MPNetForSequenceClassification"),
("TapasConfig", "TapasForSequenceClassification"),
("IBertConfig", "IBertForSequenceClassification"),
]
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("TapasConfig", "TapasForQuestionAnswering"),
]
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForTokenClassification"),
("CanineConfig", "CanineForTokenClassification"),
("RoFormerConfig", "RoFormerForTokenClassification"),
("BigBirdConfig", "BigBirdForTokenClassification"),
("ConvBertConfig", "ConvBertForTokenClassification"),
("LayoutLMConfig", "LayoutLMForTokenClassification"),
("DistilBertConfig", "DistilBertForTokenClassification"),
("CamembertConfig", "CamembertForTokenClassification"),
("FlaubertConfig", "FlaubertForTokenClassification"),
("XLMConfig", "XLMForTokenClassification"),
("XLMRobertaConfig", "XLMRobertaForTokenClassification"),
("LongformerConfig", "LongformerForTokenClassification"),
("RobertaConfig", "RobertaForTokenClassification"),
("SqueezeBertConfig", "SqueezeBertForTokenClassification"),
("BertConfig", "BertForTokenClassification"),
("MegatronBertConfig", "MegatronBertForTokenClassification"),
("MobileBertConfig", "MobileBertForTokenClassification"),
("XLNetConfig", "XLNetForTokenClassification"),
("AlbertConfig", "AlbertForTokenClassification"),
("ElectraConfig", "ElectraForTokenClassification"),
("FunnelConfig", "FunnelForTokenClassification"),
("MPNetConfig", "MPNetForTokenClassification"),
("DebertaConfig", "DebertaForTokenClassification"),
("DebertaV2Config", "DebertaV2ForTokenClassification"),
("IBertConfig", "IBertForTokenClassification"),
]
)
MODEL_MAPPING_NAMES = OrderedDict(
[
("BeitConfig", "BeitModel"),
("RemBertConfig", "RemBertModel"),
("VisualBertConfig", "VisualBertModel"),
("CanineConfig", "CanineModel"),
("RoFormerConfig", "RoFormerModel"),
("CLIPConfig", "CLIPModel"),
("BigBirdPegasusConfig", "BigBirdPegasusModel"),
("DeiTConfig", "DeiTModel"),
("LukeConfig", "LukeModel"),
("DetrConfig", "DetrModel"),
("GPTNeoConfig", "GPTNeoModel"),
("BigBirdConfig", "BigBirdModel"),
("Speech2TextConfig", "Speech2TextModel"),
("ViTConfig", "ViTModel"),
("Wav2Vec2Config", "Wav2Vec2Model"),
("HubertConfig", "HubertModel"),
("M2M100Config", "M2M100Model"),
("ConvBertConfig", "ConvBertModel"),
("LEDConfig", "LEDModel"),
("BlenderbotSmallConfig", "BlenderbotSmallModel"),
("RetriBertConfig", "RetriBertModel"),
("MT5Config", "MT5Model"),
("T5Config", "T5Model"),
("PegasusConfig", "PegasusModel"),
("MarianConfig", "MarianModel"),
("MBartConfig", "MBartModel"),
("BlenderbotConfig", "BlenderbotModel"),
("DistilBertConfig", "DistilBertModel"),
("AlbertConfig", "AlbertModel"),
("CamembertConfig", "CamembertModel"),
("XLMRobertaConfig", "XLMRobertaModel"),
("BartConfig", "BartModel"),
("LongformerConfig", "LongformerModel"),
("RobertaConfig", "RobertaModel"),
("LayoutLMConfig", "LayoutLMModel"),
("SqueezeBertConfig", "SqueezeBertModel"),
("BertConfig", "BertModel"),
("OpenAIGPTConfig", "OpenAIGPTModel"),
("GPT2Config", "GPT2Model"),
("MegatronBertConfig", "MegatronBertModel"),
("MobileBertConfig", "MobileBertModel"),
("TransfoXLConfig", "TransfoXLModel"),
("XLNetConfig", "XLNetModel"),
("FlaubertConfig", "FlaubertModel"),
("FSMTConfig", "FSMTModel"),
("XLMConfig", "XLMModel"),
("CTRLConfig", "CTRLModel"),
("ElectraConfig", "ElectraModel"),
("ReformerConfig", "ReformerModel"),
("FunnelConfig", "('FunnelModel', 'FunnelBaseModel')"),
("LxmertConfig", "LxmertModel"),
("BertGenerationConfig", "BertGenerationEncoder"),
("DebertaConfig", "DebertaModel"),
("DebertaV2Config", "DebertaV2Model"),
("DPRConfig", "DPRQuestionEncoder"),
("XLMProphetNetConfig", "XLMProphetNetModel"),
("ProphetNetConfig", "ProphetNetModel"),
("MPNetConfig", "MPNetModel"),
("TapasConfig", "TapasModel"),
("IBertConfig", "IBertModel"),
]
)
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
("RemBertConfig", "RemBertForMaskedLM"),
("RoFormerConfig", "RoFormerForMaskedLM"),
("BigBirdPegasusConfig", "BigBirdPegasusForConditionalGeneration"),
("GPTNeoConfig", "GPTNeoForCausalLM"),
("BigBirdConfig", "BigBirdForMaskedLM"),
("Speech2TextConfig", "Speech2TextForConditionalGeneration"),
("Wav2Vec2Config", "Wav2Vec2ForMaskedLM"),
("M2M100Config", "M2M100ForConditionalGeneration"),
("ConvBertConfig", "ConvBertForMaskedLM"),
("LEDConfig", "LEDForConditionalGeneration"),
("BlenderbotSmallConfig", "BlenderbotSmallForConditionalGeneration"),
("LayoutLMConfig", "LayoutLMForMaskedLM"),
("T5Config", "T5ForConditionalGeneration"),
("DistilBertConfig", "DistilBertForMaskedLM"),
("AlbertConfig", "AlbertForMaskedLM"),
("CamembertConfig", "CamembertForMaskedLM"),
("XLMRobertaConfig", "XLMRobertaForMaskedLM"),
("MarianConfig", "MarianMTModel"),
("FSMTConfig", "FSMTForConditionalGeneration"),
("BartConfig", "BartForConditionalGeneration"),
("LongformerConfig", "LongformerForMaskedLM"),
("RobertaConfig", "RobertaForMaskedLM"),
("SqueezeBertConfig", "SqueezeBertForMaskedLM"),
("BertConfig", "BertForMaskedLM"),
("OpenAIGPTConfig", "OpenAIGPTLMHeadModel"),
("GPT2Config", "GPT2LMHeadModel"),
("MegatronBertConfig", "MegatronBertForCausalLM"),
("MobileBertConfig", "MobileBertForMaskedLM"),
("TransfoXLConfig", "TransfoXLLMHeadModel"),
("XLNetConfig", "XLNetLMHeadModel"),
("FlaubertConfig", "FlaubertWithLMHeadModel"),
("XLMConfig", "XLMWithLMHeadModel"),
("CTRLConfig", "CTRLLMHeadModel"),
("ElectraConfig", "ElectraForMaskedLM"),
("EncoderDecoderConfig", "EncoderDecoderModel"),
("ReformerConfig", "ReformerModelWithLMHead"),
("FunnelConfig", "FunnelForMaskedLM"),
("MPNetConfig", "MPNetForMaskedLM"),
("TapasConfig", "TapasForMaskedLM"),
("DebertaConfig", "DebertaForMaskedLM"),
("DebertaV2Config", "DebertaV2ForMaskedLM"),
("IBertConfig", "IBertForMaskedLM"),
]
)

View File

@ -17,7 +17,7 @@
{% if cookiecutter.is_encoder_decoder_model == "False" %}
import math
from typing import Any, Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@ -1484,7 +1484,7 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPast,
TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
@ -2162,7 +2162,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)

View File

@ -172,17 +172,12 @@
# To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here"
# Replace with:
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Config"),
# End.
# Below: "# Add archive maps here"
# Replace with:
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End.
# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with:
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP"),
# End.
# Below: "# Add full (and cased) model names here"
@ -193,75 +188,47 @@ from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowerca
# To replace in: "src/transformers/models/auto/modeling_auto.py" if generating PyTorch
# Below: "from .configuration_auto import ("
# Replace with:
{{cookiecutter.camelcase_modelname}}Config,
# End.
# Below: "# Add modeling imports here"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Model,
)
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
)
{% endif -%}
# End.
# Below: "# Base model mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Model),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Model"),
# End.
# Below: "# Model with LM heads mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else %}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
# Below: "# Model for Causal LM mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForCausalLM"),
# End.
# Below: "# Model for Masked LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else -%}
{% endif -%}
# End.
# Below: "# Model for Sequence Classification mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForSequenceClassification),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
# End.
# Below: "# Model for Question Answering mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
# End.
# Below: "# Model for Token Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForTokenClassification),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
{% else -%}
{% endif -%}
# End.
@ -269,7 +236,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
# Below: "# Model for Multiple Choice mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMultipleChoice),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
{% else -%}
{% endif -%}
# End.
@ -278,54 +245,29 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else %}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
# To replace in: "src/transformers/models/auto/modeling_tf_auto.py" if generating TensorFlow
# Below: "from .configuration_auto import ("
# Replace with:
{{cookiecutter.camelcase_modelname}}Config,
# End.
# Below: "# Add modeling imports here"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Model,
)
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model,
)
{% endif -%}
# End.
# Below: "# Base model mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}Model),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}Model"),
# End.
# Below: "# Model with LM heads mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else %}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
# Below: "# Model for Causal LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForCausalLM),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForCausalLM"),
{% else -%}
{% endif -%}
# End.
@@ -333,7 +275,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
# Below: "# Model for Masked LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else -%}
{% endif -%}
# End.
@@ -341,7 +283,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
# Below: "# Model for Sequence Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
{% else -%}
{% endif -%}
# End.
@@ -349,7 +291,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
# Below: "# Model for Question Answering mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
{% else -%}
{% endif -%}
# End.
@@ -357,7 +299,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
# Below: "# Model for Token Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForTokenClassification),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
{% else -%}
{% endif -%}
# End.
@@ -365,7 +307,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
# Below: "# Model for Multiple Choice mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
{% else -%}
{% endif -%}
# End.
@@ -374,7 +316,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else %}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
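In each replacement pair above, the first line is the removed config-class entry and the second is the added name-string entry. As a minimal sketch, assuming a hypothetical model called brand_new_bert (illustrative, not part of this commit), the rendered template contributes entries like the following to the name-based mappings:

from collections import OrderedDict

# Hypothetical rendering of the directives above for a model "brand_new_bert":
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        ("brand_new_bert", "BrandNewBertForMaskedLM"),
    ]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        ("brand_new_bert", "TFBrandNewBertForMaskedLM"),
    ]
)

Because both keys and values are plain strings, adding a new model to the auto mappings no longer forces the auto modules to import its config and modeling files.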

@@ -23,7 +23,8 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
from transformers.models.mbart import MBartForConditionalGeneration
from transformers.models.mbart50 import MBart50TokenizerFast
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
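With this split, the mBART-50 tokenizers move to their own mbart50 module. A minimal import sketch under the new layout (assuming a transformers version that includes this change):

from transformers.models.mbart import MBartForConditionalGeneration
from transformers.models.mbart50 import MBart50TokenizerFast

# Both classes remain importable from the top-level package as well:
# from transformers import MBartForConditionalGeneration, MBart50TokenizerFast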

@@ -306,17 +306,17 @@ def get_all_auto_configured_models():
result = set() # To avoid duplicates we concatenate all model classes in a set.
if is_torch_available():
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
if is_tf_available():
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
if is_flax_available():
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
return [cls.__name__ for cls in result]
return [cls for cls in result]
def ignore_unautoclassed(model_name):
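The renamed constants end in _NAMES because the auto mappings now hold class-name strings rather than classes, which is why the final list comprehension can drop cls.__name__. A minimal sketch of the new traversal, with illustrative mapping content:

from collections import OrderedDict

# Illustrative stand-in for e.g. transformers.models.auto.modeling_auto.MODEL_MAPPING_NAMES
MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel"), ("gpt2", "GPT2Model")])

def get_values(mapping):
    # Flatten the mapping values; real entries can also be tuples of names.
    result = []
    for value in mapping.values():
        result += list(value) if isinstance(value, (list, tuple)) else [value]
    return result

# The values are already strings, so no __name__ lookup is needed.
assert get_values(MODEL_MAPPING_NAMES) == ["BertModel", "GPT2Model"]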

@@ -87,12 +87,13 @@ def get_model_table_from_auto_modules():
transformers = spec.loader.load_module()
# Dictionary model names to config.
config_mapping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
model_name_to_config = {
name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
}
model_name_to_prefix = {
name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
name: config_mapping_names[code]
for code, name in transformers.MODEL_NAMES_MAPPING.items()
if code in config_mapping_names
}
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
# Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
slow_tokenizers = collections.defaultdict(bool)
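The table builder likewise switches from instantiated config classes to CONFIG_MAPPING_NAMES strings, so prefixes come from a plain str.replace instead of config.__name__. A minimal sketch with illustrative entries:

# Illustrative stand-ins for the real auto-module dictionaries.
CONFIG_MAPPING_NAMES = {"bert": "BertConfig", "gpt2": "GPT2Config"}
MODEL_NAMES_MAPPING = {"bert": "BERT", "gpt2": "OpenAI GPT-2"}

model_name_to_config = {
    name: CONFIG_MAPPING_NAMES[code]
    for code, name in MODEL_NAMES_MAPPING.items()
    if code in CONFIG_MAPPING_NAMES
}
# Prefixes are derived from the class-name string, not from a class attribute.
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
assert model_name_to_prefix == {"BERT": "Bert", "OpenAI GPT-2": "GPT2"}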

@@ -1,106 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script remaps classes to class strings so these maps are quick to load and don't
# require loading all possible modeling files.
#
# it can be extended to auto-generate other dicts that are needed at runtime
import os
import sys
from os.path import abspath, dirname, join
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)
src = "src/transformers/models/auto/modeling_auto.py"
dst = "src/transformers/utils/modeling_auto_mapping.py"
if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
# speed things up by only running this script if the src is newer than dst
sys.exit(0)
# only load if needed
from transformers.models.auto.modeling_auto import ( # noqa
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
)
# Those constants don't have a name attribute, so we need to define it manually
mappings = {
"MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
"MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
"MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING,
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
"MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
"MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
"MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
"MODEL_MAPPING": MODEL_MAPPING,
"MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING,
}
def get_name(value):
if isinstance(value, tuple):
return tuple(get_name(o) for o in value)
return value.__name__
content = [
"# THIS FILE HAS BEEN AUTOGENERATED. To update:",
"# 1. modify: models/auto/modeling_auto.py",
"# 2. run: python utils/class_mapping_update.py",
"from collections import OrderedDict",
"",
]
for name, mapping in mappings.items():
entries = "\n".join([f' ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()])
content += [
"",
f"{name}_NAMES = OrderedDict(",
" [",
entries,
" ]",
")",
"",
]
print(f"Updating {dst}")
with open(dst, "w", encoding="utf-8", newline="\n") as f:
f.write("\n".join(content))