mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

Sort init import (#10801)

* Initial script
* Add script to properly sort imports in init.
* Add to the CI
* Update utils/custom_init_isort.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Separate scripts that change content from quality
* Move class_mapping_update to style_checks

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

This commit is contained in:
parent 1438c487df
commit 21e86f99e6
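
Before the diff itself, a minimal sketch of the rewrite the new `utils/custom_init_isort.py` performs (the dict setup here is a trimmed illustration; the Pegasus entry is taken verbatim from the diff below):

# Before: objects listed in whatever order they were added.
_import_structure = {"models.pegasus": []}
_import_structure["models.pegasus"].extend(
    ["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"]
)

# After `python utils/custom_init_isort.py`: isort-style order, alphabetical
# ignoring case and underscores (constants first, then classes, then functions).
_import_structure["models.pegasus"].extend(
    ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
)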
.circleci/config.yml

@@ -383,6 +383,7 @@ jobs:
           - '~/.cache/pip'
       - run: black --check examples tests src utils
       - run: isort --check-only examples tests src utils
+      - run: python utils/custom_init_isort.py --check_only
       - run: flake8 examples tests src utils
       - run: python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
       - run: python utils/check_copies.py
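With this step in place, the CI job fails whenever an `__init__.py` deviates from the canonical order (in `--check_only` mode the new script raises a `ValueError` instead of rewriting files); running `make style` locally applies the fix.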
Makefile (18 lines changed)

@@ -21,32 +21,36 @@ deps_table_update:
 
 # Check that source code meets quality standards
 
-extra_quality_checks: deps_table_update
+extra_quality_checks:
 	python utils/check_copies.py
 	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
 	python utils/style_doc.py src/transformers docs/source --max_len 119
-	python utils/class_mapping_update.py
 
 # this target runs checks on all files
 quality:
 	black --check $(check_dirs)
 	isort --check-only $(check_dirs)
+	python utils/custom_init_isort.py --check_only
 	flake8 $(check_dirs)
 	python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
 	${MAKE} extra_quality_checks
 
 # Format source code automatically and check is there are any problems left that need manual fixing
 
-style: deps_table_update
+extra_style_checks: deps_table_update
+	python utils/custom_init_isort.py
+	python utils/style_doc.py src/transformers docs/source --max_len 119
+	python utils/class_mapping_update.py
+
+# this target runs checks on all files
+style:
 	black $(check_dirs)
 	isort $(check_dirs)
-	python utils/style_doc.py src/transformers docs/source --max_len 119
+	${MAKE} extra_style_checks
 
 # Super fast fix and check target that only works on relevant modified files since the branch was made
 
-fixup: modified_only_fixup extra_quality_checks
+fixup: modified_only_fixup extra_style_checks extra_quality_checks
 
 # Make marked copies of snippets of codes conform to the original
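Net effect of the Makefile split: scripts that rewrite files in place now live in `extra_style_checks` (run by `style` and by `fixup`), `extra_quality_checks` only verifies, and `quality` additionally runs the new sort check in `--check_only` mode.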
src/transformers/__init__.py

@@ -78,6 +78,7 @@ _import_structure = {
         "xnli_processors",
         "xnli_tasks_num_labels",
     ],
+    "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
     "file_utils": [
         "CONFIG_NAME",
         "MODEL_CARD_NAME",
@@ -124,23 +125,8 @@ _import_structure = {
         "load_tf2_model_in_pytorch_model",
         "load_tf2_weights_in_pytorch_model",
     ],
-    "models": [],
-    # Models
-    "models.wav2vec2": [
-        "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "Wav2Vec2Config",
-        "Wav2Vec2CTCTokenizer",
-        "Wav2Vec2Tokenizer",
-        "Wav2Vec2FeatureExtractor",
-        "Wav2Vec2Processor",
-    ],
-    "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
-    "models.speech_to_text": [
-        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "Speech2TextConfig",
-        "Speech2TextFeatureExtractor",
-    ],
-    "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
+    "models": [],
     "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
     "models.auto": [
         "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -169,6 +155,7 @@ _import_structure = {
         "BlenderbotSmallTokenizer",
     ],
     "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
+    "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
     "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
     "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"],
     "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
@@ -193,6 +180,7 @@ _import_structure = {
     "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
     "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
     "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
+    "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
     "models.marian": ["MarianConfig"],
     "models.mbart": ["MBartConfig"],
     "models.mmbt": ["MMBTConfig"],
@@ -207,6 +195,11 @@ _import_structure = {
     "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
     "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
     "models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
+    "models.speech_to_text": [
+        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "Speech2TextConfig",
+        "Speech2TextFeatureExtractor",
+    ],
     "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
     "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
     "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
@@ -216,6 +209,14 @@ _import_structure = {
         "TransfoXLCorpus",
         "TransfoXLTokenizer",
     ],
+    "models.wav2vec2": [
+        "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "Wav2Vec2Config",
+        "Wav2Vec2CTCTokenizer",
+        "Wav2Vec2FeatureExtractor",
+        "Wav2Vec2Processor",
+        "Wav2Vec2Tokenizer",
+    ],
     "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
     "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
     "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@@ -251,7 +252,6 @@ _import_structure = {
         "SpecialTokensMixin",
         "TokenSpan",
     ],
-    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
     "trainer_callback": [
         "DefaultFlowCallback",
         "EarlyStoppingCallback",
@@ -383,54 +383,14 @@ if is_torch_available():
         "TopPLogitsWarper",
     ]
     _import_structure["generation_stopping_criteria"] = [
-        "StoppingCriteria",
-        "StoppingCriteriaList",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "StoppingCriteria",
+        "StoppingCriteriaList",
     ]
     _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
     _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
     # PyTorch models structure
-
-    _import_structure["models.speech_to_text"].extend(
-        [
-            "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "Speech2TextForConditionalGeneration",
-            "Speech2TextModel",
-        ]
-    )
-
-    _import_structure["models.wav2vec2"].extend(
-        [
-            "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "Wav2Vec2ForCTC",
-            "Wav2Vec2ForMaskedLM",
-            "Wav2Vec2Model",
-            "Wav2Vec2PreTrainedModel",
-        ]
-    )
-    _import_structure["models.m2m_100"].extend(
-        [
-            "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "M2M100ForConditionalGeneration",
-            "M2M100Model",
-        ]
-    )
-
-    _import_structure["models.convbert"].extend(
-        [
-            "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "ConvBertForMaskedLM",
-            "ConvBertForMultipleChoice",
-            "ConvBertForQuestionAnswering",
-            "ConvBertForSequenceClassification",
-            "ConvBertForTokenClassification",
-            "ConvBertLayer",
-            "ConvBertModel",
-            "ConvBertPreTrainedModel",
-            "load_tf_weights_in_convbert",
-        ]
-    )
     _import_structure["models.albert"].extend(
         [
             "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -512,17 +472,17 @@ if is_torch_available():
     _import_structure["models.blenderbot"].extend(
         [
             "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "BlenderbotForCausalLM",
             "BlenderbotForConditionalGeneration",
             "BlenderbotModel",
-            "BlenderbotForCausalLM",
         ]
     )
     _import_structure["models.blenderbot_small"].extend(
         [
            "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "BlenderbotSmallForCausalLM",
             "BlenderbotSmallForConditionalGeneration",
             "BlenderbotSmallModel",
-            "BlenderbotSmallForCausalLM",
         ]
     )
     _import_structure["models.camembert"].extend(
@@ -537,6 +497,20 @@ if is_torch_available():
             "CamembertModel",
         ]
     )
+    _import_structure["models.convbert"].extend(
+        [
+            "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "ConvBertForMaskedLM",
+            "ConvBertForMultipleChoice",
+            "ConvBertForQuestionAnswering",
+            "ConvBertForSequenceClassification",
+            "ConvBertForTokenClassification",
+            "ConvBertLayer",
+            "ConvBertModel",
+            "ConvBertPreTrainedModel",
+            "load_tf_weights_in_convbert",
+        ]
+    )
     _import_structure["models.ctrl"].extend(
         [
             "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -549,23 +523,23 @@ if is_torch_available():
     _import_structure["models.deberta"].extend(
         [
             "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "DebertaForSequenceClassification",
-            "DebertaModel",
             "DebertaForMaskedLM",
-            "DebertaPreTrainedModel",
-            "DebertaForTokenClassification",
             "DebertaForQuestionAnswering",
+            "DebertaForSequenceClassification",
+            "DebertaForTokenClassification",
+            "DebertaModel",
+            "DebertaPreTrainedModel",
         ]
     )
     _import_structure["models.deberta_v2"].extend(
         [
             "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "DebertaV2ForSequenceClassification",
-            "DebertaV2Model",
             "DebertaV2ForMaskedLM",
-            "DebertaV2PreTrainedModel",
-            "DebertaV2ForTokenClassification",
             "DebertaV2ForQuestionAnswering",
+            "DebertaV2ForSequenceClassification",
+            "DebertaV2ForTokenClassification",
+            "DebertaV2Model",
+            "DebertaV2PreTrainedModel",
         ]
     )
     _import_structure["models.distilbert"].extend(
@@ -699,7 +673,14 @@ if is_torch_available():
             "LxmertXLayer",
         ]
     )
-    _import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"])
+    _import_structure["models.m2m_100"].extend(
+        [
+            "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "M2M100ForConditionalGeneration",
+            "M2M100Model",
+        ]
+    )
+    _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
     _import_structure["models.mbart"].extend(
         [
             "MBartForCausalLM",
@@ -752,7 +733,7 @@ if is_torch_available():
         ]
     )
     _import_structure["models.pegasus"].extend(
-        ["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"]
+        ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
     )
     _import_structure["models.prophetnet"].extend(
         [
@@ -793,6 +774,13 @@ if is_torch_available():
             "RobertaModel",
         ]
     )
+    _import_structure["models.speech_to_text"].extend(
+        [
+            "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Speech2TextForConditionalGeneration",
+            "Speech2TextModel",
+        ]
+    )
     _import_structure["models.squeezebert"].extend(
         [
             "SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -836,6 +824,15 @@ if is_torch_available():
             "load_tf_weights_in_transfo_xl",
         ]
     )
+    _import_structure["models.wav2vec2"].extend(
+        [
+            "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Wav2Vec2ForCTC",
+            "Wav2Vec2ForMaskedLM",
+            "Wav2Vec2Model",
+            "Wav2Vec2PreTrainedModel",
+        ]
+    )
     _import_structure["models.xlm"].extend(
         [
             "XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -916,20 +913,6 @@ if is_tf_available():
         "shape_list",
     ]
     # TensorFlow models structure
-
-    _import_structure["models.convbert"].extend(
-        [
-            "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "TFConvBertForMaskedLM",
-            "TFConvBertForMultipleChoice",
-            "TFConvBertForQuestionAnswering",
-            "TFConvBertForSequenceClassification",
-            "TFConvBertForTokenClassification",
-            "TFConvBertLayer",
-            "TFConvBertModel",
-            "TFConvBertPreTrainedModel",
-        ]
-    )
     _import_structure["models.albert"].extend(
         [
             "TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1002,6 +985,19 @@ if is_tf_available():
             "TFCamembertModel",
         ]
     )
+    _import_structure["models.convbert"].extend(
+        [
+            "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFConvBertForMaskedLM",
+            "TFConvBertForMultipleChoice",
+            "TFConvBertForQuestionAnswering",
+            "TFConvBertForSequenceClassification",
+            "TFConvBertForTokenClassification",
+            "TFConvBertLayer",
+            "TFConvBertModel",
+            "TFConvBertPreTrainedModel",
+        ]
+    )
     _import_structure["models.ctrl"].extend(
         [
             "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1108,7 +1104,7 @@ if is_tf_available():
             "TFLxmertVisualFeatureEncoder",
         ]
     )
-    _import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
+    _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
     _import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
     _import_structure["models.mobilebert"].extend(
         [
@@ -2170,7 +2166,7 @@ if TYPE_CHECKING:
             TFLxmertPreTrainedModel,
             TFLxmertVisualFeatureEncoder,
         )
-        from .models.marian import TFMarian, TFMarianMTModel
+        from .models.marian import TFMarianModel, TFMarianMTModel
         from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
         from .models.mobilebert import (
             TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
src/transformers/models/blenderbot/__init__.py

@@ -29,10 +29,10 @@ _import_structure = {
 
 if is_torch_available():
     _import_structure["modeling_blenderbot"] = [
         "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BlenderbotForCausalLM",
         "BlenderbotForConditionalGeneration",
         "BlenderbotModel",
         "BlenderbotPreTrainedModel",
-        "BlenderbotForCausalLM",
     ]
 
src/transformers/models/blenderbot_small/__init__.py

@@ -28,10 +28,10 @@ _import_structure = {
 
 if is_torch_available():
     _import_structure["modeling_blenderbot_small"] = [
         "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BlenderbotSmallForCausalLM",
         "BlenderbotSmallForConditionalGeneration",
         "BlenderbotSmallModel",
         "BlenderbotSmallPreTrainedModel",
-        "BlenderbotSmallForCausalLM",
     ]
 
 if is_tf_available():
src/transformers/models/deberta/__init__.py

@@ -29,12 +29,12 @@ _import_structure = {
 
 if is_torch_available():
     _import_structure["modeling_deberta"] = [
         "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "DebertaForSequenceClassification",
-        "DebertaModel",
         "DebertaForMaskedLM",
-        "DebertaPreTrainedModel",
-        "DebertaForTokenClassification",
         "DebertaForQuestionAnswering",
+        "DebertaForSequenceClassification",
+        "DebertaForTokenClassification",
+        "DebertaModel",
+        "DebertaPreTrainedModel",
     ]
 
src/transformers/models/deberta_v2/__init__.py

@@ -29,12 +29,12 @@ _import_structure = {
 
 if is_torch_available():
     _import_structure["modeling_deberta_v2"] = [
         "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "DebertaV2ForSequenceClassification",
-        "DebertaV2Model",
         "DebertaV2ForMaskedLM",
-        "DebertaV2PreTrainedModel",
-        "DebertaV2ForTokenClassification",
         "DebertaV2ForQuestionAnswering",
+        "DebertaV2ForSequenceClassification",
+        "DebertaV2ForTokenClassification",
+        "DebertaV2Model",
+        "DebertaV2PreTrainedModel",
     ]
 
src/transformers/models/ibert/__init__.py

@@ -28,13 +28,13 @@ _import_structure = {
 
 if is_torch_available():
     _import_structure["modeling_ibert"] = [
         "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "IBertPreTrainedModel",
         "IBertForMaskedLM",
         "IBertForMultipleChoice",
         "IBertForQuestionAnswering",
         "IBertForSequenceClassification",
         "IBertForTokenClassification",
         "IBertModel",
+        "IBertPreTrainedModel",
     ]
 
 if TYPE_CHECKING:
src/transformers/models/marian/__init__.py

@@ -36,14 +36,14 @@ if is_sentencepiece_available():
 
 if is_torch_available():
     _import_structure["modeling_marian"] = [
         "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "MarianForCausalLM",
         "MarianModel",
         "MarianMTModel",
         "MarianPreTrainedModel",
-        "MarianForCausalLM",
     ]
 
 if is_tf_available():
-    _import_structure["modeling_tf_marian"] = ["TFMarianMTModel", "TFMarianModel"]
+    _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
 
 
 if TYPE_CHECKING:
src/transformers/models/mbart/__init__.py

@@ -35,8 +35,8 @@ if is_sentencepiece_available():
     _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
 
 if is_tokenizers_available():
-    _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
     _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
+    _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
 
 if is_torch_available():
     _import_structure["modeling_mbart"] = [
src/transformers/models/pegasus/__init__.py

@@ -39,10 +39,10 @@ if is_tokenizers_available():
 
 if is_torch_available():
     _import_structure["modeling_pegasus"] = [
         "PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "PegasusForCausalLM",
         "PegasusForConditionalGeneration",
         "PegasusModel",
         "PegasusPreTrainedModel",
-        "PegasusForCausalLM",
     ]
 
 if is_tf_available():
src/transformers/models/speech_to_text/__init__.py

@@ -29,9 +29,8 @@ _import_structure = {
 }
 
 if is_sentencepiece_available():
-    _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
     _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
-
+    _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
 
 if is_torch_available():
     _import_structure["modeling_speech_to_text"] = [
src/transformers/models/wav2vec2/__init__.py

@@ -22,16 +22,16 @@ from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
 
 _import_structure = {
     "configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
-    "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
     "feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
     "processing_wav2vec2": ["Wav2Vec2Processor"],
+    "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
 }
 
 if is_torch_available():
     _import_structure["modeling_wav2vec2"] = [
         "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "Wav2Vec2ForMaskedLM",
         "Wav2Vec2ForCTC",
+        "Wav2Vec2ForMaskedLM",
         "Wav2Vec2Model",
         "Wav2Vec2PreTrainedModel",
     ]
src/transformers/utils/dummy_tf_objects.py

@@ -1050,10 +1050,14 @@ class TFLxmertVisualFeatureEncoder:
         requires_tf(self)
 
 
-class TFMarian:
+class TFMarianModel:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
 
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
 
 class TFMarianMTModel:
     def __init__(self, *args, **kwargs):
utils/custom_init_isort.py (new file, 241 lines)

@@ -0,0 +1,241 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re

PATH_TO_TRANSFORMERS = "src/transformers"

# Pattern that looks at the indentation in a line.
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":` and captures `key` in its first group.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and captures `key` in its first group.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches `"key",` and captures `key` in its first group.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any `[stuff]` and captures `stuff` in its first group.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")

def get_indent(line):
    """Returns the indent in `line`."""
    search = _re_indent.search(line)
    return "" if search is None else search.groups()[0]


def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
    """
    Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
    `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
    after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
    """
    # Let's split the code into lines and move to start_index.
    index = 0
    lines = code.split("\n")
    if start_prompt is not None:
        while not lines[index].startswith(start_prompt):
            index += 1
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # We split into blocks until we get to the `end_prompt` (or the end of the code).
    current_block = [lines[index]]
    index += 1
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        # A non-empty line at `indent_level` marks a block boundary.
        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
            # The boundary line either closes the current block (e.g. a closing bracket) or starts a new one.
            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
                current_block.append(lines[index])
                blocks.append("\n".join(current_block))
                if index < len(lines) - 1:
                    current_block = [lines[index + 1]]
                    index += 1
                else:
                    current_block = []
            else:
                blocks.append("\n".join(current_block))
                current_block = [lines[index]]
        else:
            current_block.append(lines[index])
        index += 1

    # Adds current block if it's nonempty.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Add final block after end_prompt if provided.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks

def ignore_underscore(key):
    "Wraps a `key` function (mapping an object to a string) so the result is lower-cased with underscores removed."

    def _inner(x):
        return key(x).lower().replace("_", "")

    return _inner


def sort_objects(objects, key=None):
    "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
    # If no key is provided, we use a noop.
    def noop(x):
        return x

    if key is None:
        key = noop
    # Constants are all uppercase, they go first.
    constants = [obj for obj in objects if key(obj).isupper()]
    # Classes are not all uppercase but start with a capital, they go second.
    classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
    # Functions begin with a lowercase, they go last.
    functions = [obj for obj in objects if not key(obj)[0].isupper()]

    key1 = ignore_underscore(key)
    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)

def sort_objects_in_import(import_statement):
    """
    Return the same `import_statement` but with objects properly sorted.
    """
    # This inner function sorts imports between [ ].
    def _replace(match):
        imports = match.groups()[0]
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # We will have a final empty element if the line finished with a comma.
        if len(keys[-1]) == 0:
            keys = keys[:-1]
        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"

    lines = import_statement.split("\n")
    if len(lines) > 3:
        # Here we have to sort internal imports that are on several lines (one per name):
        # key: [
        #     "object1",
        #     "object2",
        #     ...
        # ]

        # We may have to ignore one or two lines on each side.
        idx = 2 if lines[1].strip() == "[" else 1
        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
    elif len(lines) == 3:
        # Here we have to sort internal imports that are on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        if _re_bracket_content.search(lines[1]) is not None:
            lines[1] = _re_bracket_content.sub(_replace, lines[1])
        else:
            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
            # We will have a final empty element if the line finished with a comma.
            if len(keys[-1]) == 0:
                keys = keys[:-1]
            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
        return "\n".join(lines)
    else:
        # Finally we have to deal with imports fitting on one line.
        import_statement = _re_bracket_content.sub(_replace, import_statement)
        return import_statement

def sort_imports(file, check_only=True):
    """
    Sort the `_import_structure` imports in `file`; `check_only` determines whether we only check or also overwrite.
    """
    with open(file, "r") as f:
        code = f.read()

    if "_import_structure" not in code:
        return

    # Blocks of indent level 0
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
    for block_idx in range(1, len(main_blocks) - 1):
        # Check if the block contains some `_import_structure` entries to sort.
        block = main_blocks[block_idx]
        block_lines = block.split("\n")
        if len(block_lines) < 3 or "_import_structure" not in "".join(block_lines[:2]):
            continue

        # Ignore first and last line: they don't contain anything.
        internal_block_code = "\n".join(block_lines[1:-1])
        indent = get_indent(block_lines[1])
        # Split the internal block into blocks of indent level 1.
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
        # We have two categories of import key: a plain dict entry or `_import_structure[key].append/extend`.
        pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
        # Grab the keys, but there is a trap: some lines are empty or just comments.
        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
        # We only sort the lines with a key.
        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
        count = 0
        reordered_blocks = []
        for i in range(len(internal_blocks)):
            if keys[i] is None:
                reordered_blocks.append(internal_blocks[i])
            else:
                block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
                reordered_blocks.append(block)
                count += 1

        # And we put our main block back together with its first and last line.
        main_blocks[block_idx] = "\n".join([block_lines[0]] + reordered_blocks + [block_lines[-1]])

    if code != "\n".join(main_blocks):
        if check_only:
            return True
        else:
            print(f"Overwriting {file}.")
            with open(file, "w") as f:
                f.write("\n".join(main_blocks))

def sort_imports_in_all_inits(check_only=True):
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
            if result:
                # Accumulate every failing init (assigning a fresh list here would keep only the last one).
                failures.append(os.path.join(root, "__init__.py"))
    if len(failures) > 0:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    args = parser.parse_args()

    sort_imports_in_all_inits(check_only=args.check_only)
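
A quick, hedged illustration of the ordering rules implemented above (not part of the commit; it assumes the script is importable, e.g. by running Python from the `utils/` directory):

from custom_init_isort import sort_objects

names = [
    "load_tf_weights_in_convbert",  # function: starts lowercase, sorts last
    "ConvBertModel",  # class: starts with a capital, sorts second
    "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",  # constant: all caps, sorts first
    "ConvBertLayer",
]
print(sort_objects(names))
# ['CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST', 'ConvBertLayer', 'ConvBertModel', 'load_tf_weights_in_convbert']

The script itself is invoked as `python utils/custom_init_isort.py` to rewrite the inits, or with `--check_only` (as in CI and `make quality`) to fail on unsorted imports without touching any file.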