Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
Sylvain Gugger 2021-06-23 10:40:54 -04:00 committed by GitHub
parent 53c60babe4
commit 9eda6b52e2
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
26 changed files with 532 additions and 51 deletions
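An illustration of the commit's effect (not part of the diff below): with these exports in place, the base classes resolve from the top-level package, so downstream code can subclass them without reaching into private submodules, assuming the matching backend (PyTorch here) is installed:

    from transformers import BertGenerationPreTrainedModel  # newly exported by this commit

    class MyBertGenerationVariant(BertGenerationPreTrainedModel):  # hypothetical subclass
        ...

Before the change, the same class was only reachable via the modeling submodule, e.g. transformers.models.bert_generation.modeling_bert_generation.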

src/transformers/__init__.py

@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available():
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
else:
@ -570,6 +571,7 @@ if is_torch_available():
[
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]
)
@ -597,6 +599,7 @@ if is_torch_available():
"BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel",
"BigBirdPegasusPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(
@ -605,6 +608,7 @@ if is_torch_available():
"BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotPreTrainedModel",
]
)
_import_structure["models.blenderbot_small"].extend(
@ -613,6 +617,7 @@ if is_torch_available():
"BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
]
)
_import_structure["models.camembert"].extend(
@ -754,6 +759,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]
)
@ -805,6 +811,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
)
_import_structure["models.led"].extend(
@ -814,6 +821,7 @@ if is_torch_available():
"LEDForQuestionAnswering",
"LEDForSequenceClassification",
"LEDModel",
"LEDPreTrainedModel",
]
)
_import_structure["models.longformer"].extend(
@ -825,6 +833,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]
)
@ -854,6 +863,7 @@ if is_torch_available():
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
"M2M100Model",
"M2M100PreTrainedModel",
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
@ -864,6 +874,7 @@ if is_torch_available():
"MBartForQuestionAnswering",
"MBartForSequenceClassification",
"MBartModel",
"MBartPreTrainedModel",
]
)
_import_structure["models.megatron_bert"].extend(
@ -878,6 +889,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
)
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
@ -923,7 +935,7 @@ if is_torch_available():
]
)
_import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
)
_import_structure["models.prophetnet"].extend(
[
@ -936,7 +948,9 @@ if is_torch_available():
"ProphetNetPreTrainedModel",
]
)
_import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"])
_import_structure["models.rag"].extend(
["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
)
_import_structure["models.reformer"].extend(
[
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -947,6 +961,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]
)
_import_structure["models.retribert"].extend(
@ -962,6 +977,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]
)
_import_structure["models.roformer"].extend(
@ -984,6 +1000,7 @@ if is_torch_available():
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
"Speech2TextPreTrainedModel",
]
)
_import_structure["models.squeezebert"].extend(
@ -1016,6 +1033,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]
)
_import_structure["models.transfo_xl"].extend(
@ -1197,9 +1215,11 @@ if is_tf_available():
"TFBertPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
_import_structure["models.blenderbot"].extend(
["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
)
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
)
_import_structure["models.camembert"].extend(
[
@ -1281,6 +1301,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
)
@ -1295,6 +1316,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(
@ -1329,6 +1351,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]
)
@ -1342,8 +1365,10 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder",
]
)
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
_import_structure["models.mbart"].extend(
["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
)
_import_structure["models.mobilebert"].extend(
[
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -1384,10 +1409,13 @@ if is_tf_available():
"TFOpenAIGPTPreTrainedModel",
]
)
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
_import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
_import_structure["models.rag"].extend(
[
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
@ -1538,6 +1566,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]
)
_import_structure["models.bert"].extend(
@ -1570,7 +1599,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
@ -1585,7 +1616,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"])
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
@ -1597,8 +1628,8 @@ if is_flax_available():
"FlaxRobertaPreTrainedModel",
]
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
else:
from .utils import dummy_flax_objects
@ -1949,6 +1980,7 @@ if TYPE_CHECKING:
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
else:
from .utils.dummy_timm_objects import *
@ -2074,6 +2106,7 @@ if TYPE_CHECKING:
from .models.bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)
from .models.big_bird import (
@ -2097,18 +2130,21 @@ if TYPE_CHECKING:
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
BigBirdPegasusPreTrainedModel,
)
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM,
BlenderbotForConditionalGeneration,
BlenderbotModel,
BlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
)
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2226,6 +2262,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)
from .models.gpt2 import (
@ -2267,6 +2304,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2274,6 +2312,7 @@ if TYPE_CHECKING:
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
LEDPreTrainedModel,
)
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2283,6 +2322,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)
from .models.luke import (
@ -2302,7 +2342,12 @@ if TYPE_CHECKING:
LxmertVisualFeatureEncoder,
LxmertXLayer,
)
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model
from .models.m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.mbart import (
MBartForCausalLM,
@ -2310,6 +2355,7 @@ if TYPE_CHECKING:
MBartForQuestionAnswering,
MBartForSequenceClassification,
MBartModel,
MBartPreTrainedModel,
)
from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2322,6 +2368,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import (
@ -2359,7 +2406,12 @@ if TYPE_CHECKING:
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
PegasusModel,
PegasusPreTrainedModel,
)
from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,
@ -2369,7 +2421,7 @@ if TYPE_CHECKING:
ProphetNetModel,
ProphetNetPreTrainedModel,
)
from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
@ -2379,6 +2431,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import (
@ -2390,6 +2443,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)
from .models.roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2408,6 +2462,7 @@ if TYPE_CHECKING:
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
Speech2TextModel,
Speech2TextPreTrainedModel,
)
from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2434,6 +2489,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2600,8 +2656,16 @@ if TYPE_CHECKING:
TFBertModel,
TFBertPreTrainedModel,
)
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .models.blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM,
@ -2669,6 +2733,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)
from .models.funnel import (
@ -2681,6 +2746,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)
from .models.gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -2700,6 +2766,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)
from .models.lxmert import (
@ -2710,8 +2777,8 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import TFMarianModel, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
@ -2746,8 +2813,8 @@ if TYPE_CHECKING:
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,
@ -2878,6 +2945,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)
from .models.bert import (
FlaxBertForMaskedLM,
@ -2900,7 +2968,14 @@ if TYPE_CHECKING:
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
@ -2911,7 +2986,7 @@ if TYPE_CHECKING:
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
@ -2921,8 +2996,8 @@ if TYPE_CHECKING:
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
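The hunks above only touch _import_structure entries and the mirrored TYPE_CHECKING imports. For context, here is a minimal sketch (not this commit's code; the real mechanism is transformers' _LazyModule) of how those entries drive lazy attribute resolution:

    import importlib

    # Maps a submodule to the names it exports, mirroring the entries edited above.
    _import_structure = {
        "models.rag": ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"],
    }

    def resolve(name):
        # On first access, import the owning submodule and fetch the attribute.
        for submodule, names in _import_structure.items():
            if name in names:
                module = importlib.import_module(f"transformers.{submodule}")
                return getattr(module, name)
        raise AttributeError(f"module transformers has no attribute {name}")

    RagPreTrainedModel = resolve("RagPreTrainedModel")  # submodule is imported only now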

src/transformers/models/bart/__init__.py

@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]
if TYPE_CHECKING:
@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)
else:

src/transformers/models/bert_generation/__init__.py

@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]
@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)

src/transformers/models/blenderbot/__init__.py

@ -37,7 +37,11 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
_import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
if TYPE_CHECKING:
@ -54,7 +58,11 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
else:
import importlib

src/transformers/models/blenderbot_small/__init__.py

@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
]
if TYPE_CHECKING:
@ -54,7 +55,11 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
else:
import importlib

src/transformers/models/clip/__init__.py

@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
@ -77,7 +79,14 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
else:

src/transformers/models/flaubert/__init__.py

@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)

src/transformers/models/funnel/__init__.py

@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]
@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]
@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)
@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)
else:

src/transformers/models/gpt2/__init__.py

@ -58,7 +58,7 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
@ -90,7 +90,7 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else:
import importlib

src/transformers/models/layoutlm/__init__.py

@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
if is_tf_available():
@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_layoutlm import (

src/transformers/models/longformer/__init__.py

@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]
@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]
@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)
@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)

src/transformers/models/marian/__init__.py

@ -43,7 +43,7 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
if TYPE_CHECKING:
@ -62,7 +62,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
else:
import importlib

src/transformers/models/mbart/__init__.py

@ -50,7 +50,11 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
_import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
if TYPE_CHECKING:
@ -76,7 +80,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
else:
import importlib

src/transformers/models/megatron_bert/__init__.py

@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
if TYPE_CHECKING:
@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
else:

src/transformers/models/pegasus/__init__.py

@ -46,7 +46,11 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
_import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
if TYPE_CHECKING:
@ -68,7 +72,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
else:
import importlib

src/transformers/models/rag/__init__.py

@ -28,10 +28,20 @@ _import_structure = {
}
if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING:
@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer
if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else:
import importlib

src/transformers/models/reformer/__init__.py

@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]
@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)
else:

src/transformers/models/roberta/__init__.py

@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]
if is_tf_available():
@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)
if is_tf_available():

src/transformers/models/tapas/__init__.py

@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]
@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)
else:

src/transformers/models/vit/__init__.py

@ -37,7 +37,11 @@ if is_torch_available():
if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
_import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
@ -54,7 +58,7 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else:

src/transformers/utils/dummy_flax_objects.py

@ -244,6 +244,15 @@ class FlaxBartModel:
requires_backends(cls, ["flax"])
class FlaxBartPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -412,6 +421,15 @@ class FlaxCLIPTextModel:
requires_backends(cls, ["flax"])
class FlaxCLIPTextPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxCLIPVisionModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -421,6 +439,15 @@ class FlaxCLIPVisionModel:
requires_backends(cls, ["flax"])
class FlaxCLIPVisionPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxElectraForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -507,6 +534,15 @@ class FlaxGPT2Model:
requires_backends(cls, ["flax"])
class FlaxGPT2PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxRobertaForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -588,6 +624,15 @@ class FlaxT5Model:
requires_backends(cls, ["flax"])
class FlaxT5PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@ -600,3 +645,12 @@ class FlaxViTModel:
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxViTPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
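A hedged behavior sketch for the dummy objects above (assuming an environment without flax installed): the top-level import still succeeds because the dummy class is substituted, and the missing backend is only reported on use:

    # Succeeds without flax: the dummy stands in for the real class.
    from transformers import FlaxBartPreTrainedModel

    try:
        FlaxBartPreTrainedModel()  # requires_backends fires here
    except ImportError as err:
        print(err)  # the message asks for flax to be installed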

src/transformers/utils/dummy_pt_objects.py

@ -692,6 +692,15 @@ class BertGenerationEncoder:
requires_backends(self, ["torch"])
class BertGenerationPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def load_tf_weights_in_bert_generation(*args, **kwargs):
requires_backends(load_tf_weights_in_bert_generation, ["torch"])
@ -833,6 +842,15 @@ class BigBirdPegasusModel:
requires_backends(cls, ["torch"])
class BigBirdPegasusPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -863,6 +881,15 @@ class BlenderbotModel:
requires_backends(cls, ["torch"])
class BlenderbotPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -893,6 +920,15 @@ class BlenderbotSmallModel:
requires_backends(cls, ["torch"])
class BlenderbotSmallPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -1610,6 +1646,15 @@ class FunnelModel:
requires_backends(cls, ["torch"])
class FunnelPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def load_tf_weights_in_funnel(*args, **kwargs):
requires_backends(load_tf_weights_in_funnel, ["torch"])
@ -1840,6 +1885,15 @@ class LayoutLMModel:
requires_backends(cls, ["torch"])
class LayoutLMPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -1879,6 +1933,15 @@ class LEDModel:
requires_backends(cls, ["torch"])
class LEDPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -1936,6 +1999,15 @@ class LongformerModel:
requires_backends(cls, ["torch"])
class LongformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LongformerSelfAttention:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@ -2045,6 +2117,15 @@ class M2M100Model:
requires_backends(cls, ["torch"])
class M2M100PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class MarianForCausalLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@ -2117,6 +2198,15 @@ class MBartModel:
requires_backends(cls, ["torch"])
class MBartPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -2193,6 +2283,15 @@ class MegatronBertModel:
requires_backends(cls, ["torch"])
class MegatronBertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class MMBTForClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@ -2474,6 +2573,15 @@ class PegasusModel:
requires_backends(cls, ["torch"])
class PegasusPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -2532,6 +2640,15 @@ class RagModel:
requires_backends(cls, ["torch"])
class RagPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class RagSequenceForGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@ -2600,6 +2717,15 @@ class ReformerModelWithLMHead:
requires_backends(cls, ["torch"])
class ReformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -2687,6 +2813,15 @@ class RobertaModel:
requires_backends(cls, ["torch"])
class RobertaPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -2792,6 +2927,15 @@ class Speech2TextModel:
requires_backends(cls, ["torch"])
class Speech2TextPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -2945,6 +3089,15 @@ class TapasModel:
requires_backends(cls, ["torch"])
class TapasPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None

src/transformers/utils/dummy_tf_objects.py

@ -431,6 +431,15 @@ class TFBlenderbotModel:
requires_backends(cls, ["tf"])
class TFBlenderbotPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFBlenderbotSmallForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@ -449,6 +458,15 @@ class TFBlenderbotSmallModel:
requires_backends(cls, ["tf"])
class TFBlenderbotSmallPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -845,6 +863,15 @@ class TFFlaubertModel:
requires_backends(cls, ["tf"])
class TFFlaubertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFFlaubertWithLMHeadModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@ -925,6 +952,15 @@ class TFFunnelModel:
requires_backends(cls, ["tf"])
class TFFunnelPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -1062,6 +1098,15 @@ class TFLongformerModel:
requires_backends(cls, ["tf"])
class TFLongformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFLongformerSelfAttention:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@ -1121,6 +1166,15 @@ class TFMarianMTModel:
requires_backends(cls, ["tf"])
class TFMarianPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFMBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@ -1139,6 +1193,15 @@ class TFMBartModel:
requires_backends(cls, ["tf"])
class TFMBartPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -1389,6 +1452,15 @@ class TFPegasusModel:
requires_backends(cls, ["tf"])
class TFPegasusPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFRagModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@ -1398,6 +1470,15 @@ class TFRagModel:
requires_backends(cls, ["tf"])
class TFRagPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFRagSequenceForGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])

src/transformers/utils/dummy_timm_objects.py

@ -30,3 +30,12 @@ class DetrModel:
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])
class DetrPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])

templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/__init__.py

@ -52,6 +52,7 @@
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
)
{% endif -%}
@ -120,6 +121,7 @@
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif -%}
# End.
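For illustration, a hedged sketch of how cookiecutter expands the placeholder above (rendered here with jinja2, which cookiecutter builds on; the model name is hypothetical):

    from jinja2 import Template

    line = '"{{cookiecutter.camelcase_modelname}}PreTrainedModel",'
    print(Template(line).render(cookiecutter={"camelcase_modelname": "NewModel"}))
    # -> "NewModelPreTrainedModel",

so a model generated from the template exports its PreTrainedModel base class out of the box.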

utils/check_repo.py

@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source"
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
"DPRSpanPredictor",
"T5Stack",
"TFDPRSpanPredictor",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [
"PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
"TFDPRSpanPredictor", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix
"SeparableConv1D", # Building part of bigger (tested) model.
@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"CLIPTextModel",
"CLIPVisionModel",
@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"FlaxCLIPVisionModel",
"DetrForSegmentation",
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"GPT2DoubleHeadsModel",
"LukeForEntityClassification",
@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFDPRReader",
"TFDPRSpanPredictor",
"TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel",
"TFRagModel",
@ -173,12 +174,12 @@ def get_model_modules():
return modules
def get_models(module):
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name:
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
continue
attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
@ -186,6 +187,36 @@ def get_models(module):
return models
def is_a_private_model(model):
"""Returns True if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
# Wrapper, Encoder and Decoder are all privates
if model.endswith("Wrapper"):
return True
if model.endswith("Encoder"):
return True
if model.endswith("Decoder"):
return True
return False
def check_models_are_in_init():
"""Checks all models defined in the library are in the main init."""
models_not_in_init = []
dir_transformers = dir(transformers)
for module in get_model_modules():
models_not_in_init += [
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
]
# Remove private models
models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
if len(models_not_in_init) > 0:
raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
@ -229,6 +260,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file."""
# XxxPreTrainedModel are not tested
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
if tested_models is None:
@ -515,6 +547,8 @@ def check_all_objects_are_documented():
def check_repo_quality():
"""Check all models are properly tested and documented."""
print("Checking all models are public.")
check_models_are_in_init()
print("Checking all models are properly tested.")
check_all_decorator_order()
check_all_models_are_tested()
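To see the new filter in isolation, here is a standalone sketch that copies its logic from the diff above (the example names come from the repo's own lists):

    PRIVATE_MODELS = ["DPRSpanPredictor", "T5Stack", "TFDPRSpanPredictor"]

    def is_a_private_model(model):
        # Explicitly private, or an internal Wrapper/Encoder/Decoder building block.
        if model in PRIVATE_MODELS:
            return True
        return model.endswith(("Wrapper", "Encoder", "Decoder"))

    assert is_a_private_model("T5Stack")                     # explicitly listed
    assert is_a_private_model("PegasusDecoderWrapper")       # suffix rule
    assert not is_a_private_model("PegasusPreTrainedModel")  # must be in the main init

check_models_are_in_init() then runs get_models(module, include_pretrained=True) over every modeling module and reports any public name missing from the top-level namespace; since it is wired into check_repo_quality(), the repo's quality checks enforce the new exports from now on.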