Mirror of https://github.com/huggingface/transformers.git
Add all XxxPreTrainedModel to the main init (#12314)
* Add all XxxPreTrainedModel to the main init
* Add to template
* Add to template bis
* Add FlaxT5
parent 53c60babe4
commit 9eda6b52e2
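The hunks below repeatedly touch two parallel structures in each __init__.py: the _import_structure dictionary that drives lazy imports at runtime, and the "if TYPE_CHECKING:" block that mirrors the same names for static type checkers and IDEs. As a reading aid, here is a minimal, self-contained sketch of that pattern; the _LazyModule class below is a simplified stand-in written for illustration, not the library's actual implementation:

import importlib
import types
from typing import TYPE_CHECKING

# Public names are registered per submodule; nothing heavy is imported yet.
_import_structure = {
    "models.pegasus": ["PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"],
}


class _LazyModule(types.ModuleType):
    # Simplified stand-in: resolve an exported name to its submodule on first access.
    def __init__(self, name, import_structure):
        super().__init__(name)
        self._name_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}
        self.__all__ = list(self._name_to_module)

    def __getattr__(self, attr):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(f"{self.__name__}.{self._name_to_module[attr]}")
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so later lookups skip the import machinery
        return value


if TYPE_CHECKING:
    # Mirrored "real" imports so IDEs and type checkers can resolve every public name,
    # including the XxxPreTrainedModel classes this commit exposes.
    from .models.pegasus import PegasusPreTrainedModel  # noqa: F401
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)

Forgetting to register a class in both halves is exactly the inconsistency the check_repo.py changes at the end of this diff are meant to catch.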
@@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available():
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
else:

@@ -570,6 +571,7 @@ if is_torch_available():
[
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]
)

@@ -597,6 +599,7 @@ if is_torch_available():
"BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel",
"BigBirdPegasusPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(

@@ -605,6 +608,7 @@ if is_torch_available():
"BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotPreTrainedModel",
]
)
_import_structure["models.blenderbot_small"].extend(

@@ -613,6 +617,7 @@ if is_torch_available():
"BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
]
)
_import_structure["models.camembert"].extend(

@@ -754,6 +759,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]
)

@@ -805,6 +811,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
)
_import_structure["models.led"].extend(

@@ -814,6 +821,7 @@ if is_torch_available():
"LEDForQuestionAnswering",
"LEDForSequenceClassification",
"LEDModel",
"LEDPreTrainedModel",
]
)
_import_structure["models.longformer"].extend(

@@ -825,6 +833,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]
)

@@ -854,6 +863,7 @@ if is_torch_available():
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
"M2M100Model",
"M2M100PreTrainedModel",
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])

@@ -864,6 +874,7 @@ if is_torch_available():
"MBartForQuestionAnswering",
"MBartForSequenceClassification",
"MBartModel",
"MBartPreTrainedModel",
]
)
_import_structure["models.megatron_bert"].extend(

@@ -878,6 +889,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
)
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])

@@ -923,7 +935,7 @@ if is_torch_available():
]
)
_import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
)
_import_structure["models.prophetnet"].extend(
[

@@ -936,7 +948,9 @@ if is_torch_available():
"ProphetNetPreTrainedModel",
]
)
_import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"])
_import_structure["models.rag"].extend(
["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
)
_import_structure["models.reformer"].extend(
[
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",

@@ -947,6 +961,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]
)
_import_structure["models.retribert"].extend(

@@ -962,6 +977,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]
)
_import_structure["models.roformer"].extend(

@@ -984,6 +1000,7 @@ if is_torch_available():
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
"Speech2TextPreTrainedModel",
]
)
_import_structure["models.squeezebert"].extend(

@@ -1016,6 +1033,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]
)
_import_structure["models.transfo_xl"].extend(

@@ -1197,9 +1215,11 @@ if is_tf_available():
"TFBertPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
_import_structure["models.blenderbot"].extend(
["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
)
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
)
_import_structure["models.camembert"].extend(
[

@@ -1281,6 +1301,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
)

@@ -1295,6 +1316,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(

@@ -1329,6 +1351,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]
)

@@ -1342,8 +1365,10 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder",
]
)
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
_import_structure["models.mbart"].extend(
["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
)
_import_structure["models.mobilebert"].extend(
[
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",

@@ -1384,10 +1409,13 @@ if is_tf_available():
"TFOpenAIGPTPreTrainedModel",
]
)
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
_import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
_import_structure["models.rag"].extend(
[
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]

@@ -1538,6 +1566,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]
)
_import_structure["models.bert"].extend(

@@ -1570,7 +1599,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.electra"].extend(

@@ -1585,7 +1616,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"])
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",

@@ -1597,8 +1628,8 @@ if is_flax_available():
"FlaxRobertaPreTrainedModel",
]
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
else:
from .utils import dummy_flax_objects
@@ -1949,6 +1980,7 @@ if TYPE_CHECKING:
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
else:
from .utils.dummy_timm_objects import *

@@ -2074,6 +2106,7 @@ if TYPE_CHECKING:
from .models.bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)
from .models.big_bird import (

@@ -2097,18 +2130,21 @@ if TYPE_CHECKING:
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
BigBirdPegasusPreTrainedModel,
)
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM,
BlenderbotForConditionalGeneration,
BlenderbotModel,
BlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
)
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2226,6 +2262,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)
from .models.gpt2 import (

@@ -2267,6 +2304,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2274,6 +2312,7 @@ if TYPE_CHECKING:
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
LEDPreTrainedModel,
)
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2283,6 +2322,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)
from .models.luke import (

@@ -2302,7 +2342,12 @@ if TYPE_CHECKING:
LxmertVisualFeatureEncoder,
LxmertXLayer,
)
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model
from .models.m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.mbart import (
MBartForCausalLM,

@@ -2310,6 +2355,7 @@ if TYPE_CHECKING:
MBartForQuestionAnswering,
MBartForSequenceClassification,
MBartModel,
MBartPreTrainedModel,
)
from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2322,6 +2368,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import (

@@ -2359,7 +2406,12 @@ if TYPE_CHECKING:
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
PegasusModel,
PegasusPreTrainedModel,
)
from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,

@@ -2369,7 +2421,7 @@ if TYPE_CHECKING:
ProphetNetModel,
ProphetNetPreTrainedModel,
)
from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,

@@ -2379,6 +2431,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import (

@@ -2390,6 +2443,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)
from .models.roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2408,6 +2462,7 @@ if TYPE_CHECKING:
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
Speech2TextModel,
Speech2TextPreTrainedModel,
)
from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2434,6 +2489,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2600,8 +2656,16 @@ if TYPE_CHECKING:
TFBertModel,
TFBertPreTrainedModel,
)
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .models.blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM,

@@ -2669,6 +2733,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)
from .models.funnel import (

@@ -2681,6 +2746,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)
from .models.gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,

@@ -2700,6 +2766,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)
from .models.lxmert import (

@@ -2710,8 +2777,8 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import TFMarianModel, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,

@@ -2746,8 +2813,8 @@ if TYPE_CHECKING:
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,

@@ -2878,6 +2945,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)
from .models.bert import (
FlaxBertForMaskedLM,

@@ -2900,7 +2968,14 @@ if TYPE_CHECKING:
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,

@@ -2911,7 +2986,7 @@ if TYPE_CHECKING:
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,

@@ -2921,8 +2996,8 @@ if TYPE_CHECKING:
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
@@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]

if TYPE_CHECKING:

@@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)

else:

@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]

@@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)

@@ -37,7 +37,11 @@ if is_torch_available():

if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
_import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]

if TYPE_CHECKING:

@@ -54,7 +58,11 @@ if TYPE_CHECKING:
)

if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)

else:
import importlib

@@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
]

if TYPE_CHECKING:

@@ -54,7 +55,11 @@ if TYPE_CHECKING:
)

if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)

else:
import importlib

@@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]

@@ -77,7 +79,14 @@ if TYPE_CHECKING:
)

if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)

else:

@@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]

@@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)

@@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]

@@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]

@@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)

@@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)

else:

@@ -58,7 +58,7 @@ if is_tf_available():
]

if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]

if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config

@@ -90,7 +90,7 @@ if TYPE_CHECKING:
)

if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel

else:
import importlib

@@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]

if is_tf_available():

@@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_layoutlm import (

@@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]

@@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]

@@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)

@@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)

@@ -43,7 +43,7 @@ if is_torch_available():
]

if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]

if TYPE_CHECKING:

@@ -62,7 +62,7 @@ if TYPE_CHECKING:
)

if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel

else:
import importlib

@@ -50,7 +50,11 @@ if is_torch_available():
]

if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
_import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]

if TYPE_CHECKING:

@@ -76,7 +80,7 @@ if TYPE_CHECKING:
)

if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel

else:
import importlib

@@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]

if TYPE_CHECKING:

@@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)

else:

@@ -46,7 +46,11 @@ if is_torch_available():
]

if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
_import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]

if TYPE_CHECKING:

@@ -68,7 +72,7 @@ if TYPE_CHECKING:
)

if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel

else:
import importlib

@@ -28,10 +28,20 @@ _import_structure = {
}

if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]

if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]

if TYPE_CHECKING:

@@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer

if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration

if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)

else:
import importlib

@@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]

@@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)

else:

@@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]

if is_tf_available():

@@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)

if is_tf_available():

@@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]

@@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)

else:

@@ -37,7 +37,11 @@ if is_torch_available():

if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
_import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

@@ -54,7 +58,7 @@ if TYPE_CHECKING:
)

if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel

else:
@@ -244,6 +244,15 @@ class FlaxBartModel:
        requires_backends(cls, ["flax"])


class FlaxBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxBertForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -412,6 +421,15 @@ class FlaxCLIPTextModel:
        requires_backends(cls, ["flax"])


class FlaxCLIPTextPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxCLIPVisionModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -421,6 +439,15 @@ class FlaxCLIPVisionModel:
        requires_backends(cls, ["flax"])


class FlaxCLIPVisionPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxElectraForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -507,6 +534,15 @@ class FlaxGPT2Model:
        requires_backends(cls, ["flax"])


class FlaxGPT2PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxRobertaForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -588,6 +624,15 @@ class FlaxT5Model:
        requires_backends(cls, ["flax"])


class FlaxT5PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxViTForImageClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -600,3 +645,12 @@ class FlaxViTModel:
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxViTPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
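All of the dummy classes in the hunks above (and in the torch and TensorFlow sections that follow) share one shape: the placeholder's constructor and from_pretrained call requires_backends, which fails with an explanatory error when the backend is not installed, while the import itself still succeeds so the name can live in the main init. A hedged illustration of the intended user-facing behaviour (the exact error type and wording come from the library and may differ):

# Hypothetical session on an installation without Flax:
from transformers import FlaxT5PreTrainedModel  # succeeds, resolves to the dummy class

try:
    model = FlaxT5PreTrainedModel()  # instantiation is what triggers the backend check
except Exception as err:  # an ImportError-like error in practice; caught broadly in this sketch
    print(type(err).__name__, err)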
@@ -692,6 +692,15 @@ class BertGenerationEncoder:
        requires_backends(self, ["torch"])


class BertGenerationPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def load_tf_weights_in_bert_generation(*args, **kwargs):
    requires_backends(load_tf_weights_in_bert_generation, ["torch"])

@@ -833,6 +842,15 @@ class BigBirdPegasusModel:
        requires_backends(cls, ["torch"])


class BigBirdPegasusPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -863,6 +881,15 @@ class BlenderbotModel:
        requires_backends(cls, ["torch"])


class BlenderbotPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -893,6 +920,15 @@ class BlenderbotSmallModel:
        requires_backends(cls, ["torch"])


class BlenderbotSmallPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -1610,6 +1646,15 @@ class FunnelModel:
        requires_backends(cls, ["torch"])


class FunnelPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def load_tf_weights_in_funnel(*args, **kwargs):
    requires_backends(load_tf_weights_in_funnel, ["torch"])

@@ -1840,6 +1885,15 @@ class LayoutLMModel:
        requires_backends(cls, ["torch"])


class LayoutLMPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


LED_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -1879,6 +1933,15 @@ class LEDModel:
        requires_backends(cls, ["torch"])


class LEDPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -1936,6 +1999,15 @@ class LongformerModel:
        requires_backends(cls, ["torch"])


class LongformerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LongformerSelfAttention:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2045,6 +2117,15 @@ class M2M100Model:
        requires_backends(cls, ["torch"])


class M2M100PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class MarianForCausalLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2117,6 +2198,15 @@ class MBartModel:
        requires_backends(cls, ["torch"])


class MBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -2193,6 +2283,15 @@ class MegatronBertModel:
        requires_backends(cls, ["torch"])


class MegatronBertPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class MMBTForClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2474,6 +2573,15 @@ class PegasusModel:
        requires_backends(cls, ["torch"])


class PegasusPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -2532,6 +2640,15 @@ class RagModel:
        requires_backends(cls, ["torch"])


class RagPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class RagSequenceForGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2600,6 +2717,15 @@ class ReformerModelWithLMHead:
        requires_backends(cls, ["torch"])


class ReformerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -2687,6 +2813,15 @@ class RobertaModel:
        requires_backends(cls, ["torch"])


class RobertaPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -2792,6 +2927,15 @@ class Speech2TextModel:
        requires_backends(cls, ["torch"])


class Speech2TextPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -2945,6 +3089,15 @@ class TapasModel:
        requires_backends(cls, ["torch"])


class TapasPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -431,6 +431,15 @@ class TFBlenderbotModel:
        requires_backends(cls, ["tf"])


class TFBlenderbotPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFBlenderbotSmallForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -449,6 +458,15 @@ class TFBlenderbotSmallModel:
        requires_backends(cls, ["tf"])


class TFBlenderbotSmallPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -845,6 +863,15 @@ class TFFlaubertModel:
        requires_backends(cls, ["tf"])


class TFFlaubertPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFFlaubertWithLMHeadModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -925,6 +952,15 @@ class TFFunnelModel:
        requires_backends(cls, ["tf"])


class TFFunnelPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -1062,6 +1098,15 @@ class TFLongformerModel:
        requires_backends(cls, ["tf"])


class TFLongformerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFLongformerSelfAttention:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -1121,6 +1166,15 @@ class TFMarianMTModel:
        requires_backends(cls, ["tf"])


class TFMarianPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFMBartForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -1139,6 +1193,15 @@ class TFMBartModel:
        requires_backends(cls, ["tf"])


class TFMBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -1389,6 +1452,15 @@ class TFPegasusModel:
        requires_backends(cls, ["tf"])


class TFPegasusPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFRagModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -1398,6 +1470,15 @@ class TFRagModel:
        requires_backends(cls, ["tf"])


class TFRagPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFRagSequenceForGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
@@ -30,3 +30,12 @@ class DetrModel:
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["timm", "vision"])


class DetrPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["timm", "vision"])
@@ -52,6 +52,7 @@
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
)
{% endif -%}

@@ -120,6 +121,7 @@
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif -%}
# End.
@@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source"

# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "T5Stack",
    "TFDPRSpanPredictor",
]

# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    # models to ignore for not tested
    "BigBirdPegasusEncoder", # Building part of bigger (tested) model.
    "BigBirdPegasusDecoder", # Building part of bigger (tested) model.

@@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [
    "PegasusEncoder", # Building part of bigger (tested) model.
    "PegasusDecoderWrapper", # Building part of bigger (tested) model.
    "DPREncoder", # Building part of bigger (tested) model.
    "DPRSpanPredictor", # Building part of bigger (tested) model.
    "ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
    "ReformerForMaskedLM", # Needs to be setup as decoder.
    "T5Stack", # Building part of bigger (tested) model.
    "TFDPREncoder", # Building part of bigger (tested) model.
    "TFDPRSpanPredictor", # Building part of bigger (tested) model.
    "TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
    "TFRobertaForMultipleChoice", # TODO: fix
    "SeparableConv1D", # Building part of bigger (tested) model.

@@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [

# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "CLIPTextModel",
    "CLIPVisionModel",

@@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
    "FlaxCLIPVisionModel",
    "DetrForSegmentation",
    "DPRReader",
    "DPRSpanPredictor",
    "FlaubertForQuestionAnswering",
    "GPT2DoubleHeadsModel",
    "LukeForEntityClassification",

@@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
    "T5Stack",
    "TFDPRReader",
    "TFDPRSpanPredictor",
    "TFGPT2DoubleHeadsModel",
    "TFOpenAIGPTDoubleHeadsModel",
    "TFRagModel",

@@ -173,12 +174,12 @@ def get_model_modules():
    return modules


def get_models(module):
def get_models(module, include_pretrained=False):
    """Get the objects in module that are models."""
    models = []
    model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
    for attr_name in dir(module):
        if "Pretrained" in attr_name or "PreTrained" in attr_name:
        if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
            continue
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:

@@ -186,6 +187,36 @@ def get_models(module):
    return models


def is_a_private_model(model):
    """Returns True if the model should not be in the main init."""
    if model in PRIVATE_MODELS:
        return True

    # Wrapper, Encoder and Decoder are all privates
    if model.endswith("Wrapper"):
        return True
    if model.endswith("Encoder"):
        return True
    if model.endswith("Decoder"):
        return True
    return False
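For reference, the new is_a_private_model helper treats both the explicit PRIVATE_MODELS entries and anything named like an internal building block as private. A few illustrative calls, consistent with the rules shown above (this snippet is not part of the commit):

assert is_a_private_model("T5Stack")  # listed in PRIVATE_MODELS
assert is_a_private_model("PegasusDecoderWrapper")  # "...Wrapper" suffix
assert is_a_private_model("BigBirdPegasusEncoder")  # "...Encoder" suffix
assert not is_a_private_model("PegasusPreTrainedModel")  # public, so it must appear in the main init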
def check_models_are_in_init():
    """Checks all models defined in the library are in the main init."""
    models_not_in_init = []
    dir_transformers = dir(transformers)
    for module in get_model_modules():
        models_not_in_init += [
            model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
        ]

    # Remove private models
    models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
    if len(models_not_in_init) > 0:
        raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")


# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():

@@ -229,6 +260,7 @@ def find_tested_models(test_file):

def check_models_are_tested(module, test_file):
    """Check models defined in module are tested in test_file."""
    # XxxPreTrainedModel are not tested
    defined_models = get_models(module)
    tested_models = find_tested_models(test_file)
    if tested_models is None:

@@ -515,6 +547,8 @@ def check_all_objects_are_documented():

def check_repo_quality():
    """Check all models are properly tested and documented."""
    print("Checking all models are public.")
    check_models_are_in_init()
    print("Checking all models are properly tested.")
    check_all_decorator_order()
    check_all_models_are_tested()