Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
Sort init import (#10801)

* Initial script
* Add script to properly sort imports in init.
* Add to the CI
* Update utils/custom_init_isort.py
* Separate scripts that change content from quality
* Move class_mapping_update to style_checks

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

This commit is contained in:
parent 1438c487df
commit 21e86f99e6
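
The ordering the new script enforces mirrors isort's convention: uppercase constants first, then classes, then lowercase functions, each group sorted alphabetically while ignoring case and underscores. A minimal standalone sketch of that rule (illustration only, not part of the commit):

def sort_like_isort(names):
    # Constants (ALL_CAPS) go first, classes (Capitalized) second,
    # functions (lowercase) last -- same buckets as the new script.
    constants = [n for n in names if n.isupper()]
    classes = [n for n in names if n[0].isupper() and not n.isupper()]
    functions = [n for n in names if not n[0].isupper()]
    # Compare case-insensitively and ignore underscores within each bucket.
    key = lambda n: n.lower().replace("_", "")
    return sorted(constants, key=key) + sorted(classes, key=key) + sorted(functions, key=key)

print(sort_like_isort(["Wav2Vec2Model", "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "load_tf_weights_in_convbert", "Wav2Vec2ForCTC"]))
# ['WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST', 'Wav2Vec2ForCTC', 'Wav2Vec2Model', 'load_tf_weights_in_convbert']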
.circleci/config.yml

@@ -383,6 +383,7 @@ jobs:
       - '~/.cache/pip'
       - run: black --check examples tests src utils
       - run: isort --check-only examples tests src utils
+      - run: python utils/custom_init_isort.py --check_only
       - run: flake8 examples tests src utils
       - run: python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
       - run: python utils/check_copies.py
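
In `--check_only` mode the script raises when any init would change, so this CI step fails on unsorted imports without touching the tree. A rough equivalent of what the step does (hypothetical wrapper, not part of the commit):

import subprocess
import sys

# Run the checker; a non-zero exit (the script raises ValueError) fails the build.
result = subprocess.run([sys.executable, "utils/custom_init_isort.py", "--check_only"])
if result.returncode != 0:
    sys.exit("Init imports are not properly sorted; run `make style` locally.")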
Makefile (18 changed lines)
@@ -21,32 +21,36 @@ deps_table_update:
 
 # Check that source code meets quality standards
 
-extra_quality_checks: deps_table_update
+extra_quality_checks:
 	python utils/check_copies.py
 	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
-	python utils/style_doc.py src/transformers docs/source --max_len 119
-	python utils/class_mapping_update.py
 
 # this target runs checks on all files
 quality:
 	black --check $(check_dirs)
 	isort --check-only $(check_dirs)
+	python utils/custom_init_isort.py --check_only
 	flake8 $(check_dirs)
 	python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
 	${MAKE} extra_quality_checks
 
 # Format source code automatically and check is there are any problems left that need manual fixing
 
-style: deps_table_update
+extra_style_checks: deps_table_update
+	python utils/custom_init_isort.py
+	python utils/style_doc.py src/transformers docs/source --max_len 119
+	python utils/class_mapping_update.py
+
+# this target runs checks on all files
+style:
 	black $(check_dirs)
 	isort $(check_dirs)
-	python utils/style_doc.py src/transformers docs/source --max_len 119
+	${MAKE} extra_style_checks
 
 # Super fast fix and check target that only works on relevant modified files since the branch was made
 
-fixup: modified_only_fixup extra_quality_checks
+fixup: modified_only_fixup extra_style_checks extra_quality_checks
 
 # Make marked copies of snippets of codes conform to the original
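
All of the __init__.py hunks below reorder entries of `_import_structure` dicts. These dicts drive transformers' lazy imports (via `_BaseLazyModule` at this commit): a name is only imported from its submodule on first attribute access. A simplified sketch of the pattern, not the library's actual implementation:

import importlib

class LazyModule:
    def __init__(self, name, import_structure):
        self._name = name
        # Map each exported name back to the submodule that defines it.
        self._object_to_module = {
            obj: module for module, objs in import_structure.items() for obj in objs
        }

    def __getattr__(self, attr):
        # Import the defining submodule only when the name is first requested.
        module = importlib.import_module(f".{self._object_to_module[attr]}", self._name)
        return getattr(module, attr)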
src/transformers/__init__.py

@@ -78,6 +78,7 @@ _import_structure = {
         "xnli_processors",
         "xnli_tasks_num_labels",
     ],
+    "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
     "file_utils": [
         "CONFIG_NAME",
         "MODEL_CARD_NAME",
@@ -124,23 +125,8 @@ _import_structure = {
         "load_tf2_model_in_pytorch_model",
         "load_tf2_weights_in_pytorch_model",
     ],
-    "models": [],
     # Models
-    "models.wav2vec2": [
-        "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "Wav2Vec2Config",
-        "Wav2Vec2CTCTokenizer",
-        "Wav2Vec2Tokenizer",
-        "Wav2Vec2FeatureExtractor",
-        "Wav2Vec2Processor",
-    ],
-    "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
-    "models.speech_to_text": [
-        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "Speech2TextConfig",
-        "Speech2TextFeatureExtractor",
-    ],
-    "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
+    "models": [],
     "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
     "models.auto": [
         "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
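
The reordering above is a straight alphabetical sort of the dict keys; comment lines such as `# Models` keep their position while the keyed entries around them move to their sorted slots. For the keys involved:

# Plain string sort explains where each model config block lands.
keys = ["models.wav2vec2", "models.m2m_100", "models.speech_to_text", "models.convbert", "models.albert"]
print(sorted(keys))
# ['models.albert', 'models.convbert', 'models.m2m_100', 'models.speech_to_text', 'models.wav2vec2']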
@@ -169,6 +155,7 @@ _import_structure = {
         "BlenderbotSmallTokenizer",
     ],
     "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
+    "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
     "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
     "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"],
     "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
@@ -193,6 +180,7 @@ _import_structure = {
     "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
     "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
     "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
+    "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
     "models.marian": ["MarianConfig"],
     "models.mbart": ["MBartConfig"],
     "models.mmbt": ["MMBTConfig"],
@@ -207,6 +195,11 @@ _import_structure = {
     "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
     "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
     "models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
+    "models.speech_to_text": [
+        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "Speech2TextConfig",
+        "Speech2TextFeatureExtractor",
+    ],
     "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
     "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
     "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
@@ -216,6 +209,14 @@ _import_structure = {
         "TransfoXLCorpus",
         "TransfoXLTokenizer",
     ],
+    "models.wav2vec2": [
+        "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "Wav2Vec2Config",
+        "Wav2Vec2CTCTokenizer",
+        "Wav2Vec2FeatureExtractor",
+        "Wav2Vec2Processor",
+        "Wav2Vec2Tokenizer",
+    ],
     "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
     "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
     "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@@ -251,7 +252,6 @@ _import_structure = {
         "SpecialTokensMixin",
         "TokenSpan",
     ],
-    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
     "trainer_callback": [
         "DefaultFlowCallback",
         "EarlyStoppingCallback",
@@ -383,54 +383,14 @@ if is_torch_available():
         "TopPLogitsWarper",
     ]
     _import_structure["generation_stopping_criteria"] = [
-        "StoppingCriteria",
-        "StoppingCriteriaList",
         "MaxLengthCriteria",
         "MaxTimeCriteria",
+        "StoppingCriteria",
+        "StoppingCriteriaList",
     ]
     _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
     _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
     # PyTorch models structure
-
-    _import_structure["models.speech_to_text"].extend(
-        [
-            "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "Speech2TextForConditionalGeneration",
-            "Speech2TextModel",
-        ]
-    )
-
-    _import_structure["models.wav2vec2"].extend(
-        [
-            "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "Wav2Vec2ForCTC",
-            "Wav2Vec2ForMaskedLM",
-            "Wav2Vec2Model",
-            "Wav2Vec2PreTrainedModel",
-        ]
-    )
-    _import_structure["models.m2m_100"].extend(
-        [
-            "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "M2M100ForConditionalGeneration",
-            "M2M100Model",
-        ]
-    )
-
-    _import_structure["models.convbert"].extend(
-        [
-            "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "ConvBertForMaskedLM",
-            "ConvBertForMultipleChoice",
-            "ConvBertForQuestionAnswering",
-            "ConvBertForSequenceClassification",
-            "ConvBertForTokenClassification",
-            "ConvBertLayer",
-            "ConvBertModel",
-            "ConvBertPreTrainedModel",
-            "load_tf_weights_in_convbert",
-        ]
-    )
     _import_structure["models.albert"].extend(
         [
             "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -512,17 +472,17 @@ if is_torch_available():
     _import_structure["models.blenderbot"].extend(
         [
             "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "BlenderbotForCausalLM",
             "BlenderbotForConditionalGeneration",
             "BlenderbotModel",
-            "BlenderbotForCausalLM",
         ]
     )
     _import_structure["models.blenderbot_small"].extend(
         [
             "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "BlenderbotSmallForCausalLM",
             "BlenderbotSmallForConditionalGeneration",
             "BlenderbotSmallModel",
-            "BlenderbotSmallForCausalLM",
         ]
     )
     _import_structure["models.camembert"].extend(
@@ -537,6 +497,20 @@ if is_torch_available():
             "CamembertModel",
         ]
     )
+    _import_structure["models.convbert"].extend(
+        [
+            "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "ConvBertForMaskedLM",
+            "ConvBertForMultipleChoice",
+            "ConvBertForQuestionAnswering",
+            "ConvBertForSequenceClassification",
+            "ConvBertForTokenClassification",
+            "ConvBertLayer",
+            "ConvBertModel",
+            "ConvBertPreTrainedModel",
+            "load_tf_weights_in_convbert",
+        ]
+    )
     _import_structure["models.ctrl"].extend(
         [
             "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -549,23 +523,23 @@ if is_torch_available():
     _import_structure["models.deberta"].extend(
         [
             "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "DebertaForSequenceClassification",
-            "DebertaModel",
             "DebertaForMaskedLM",
-            "DebertaPreTrainedModel",
-            "DebertaForTokenClassification",
             "DebertaForQuestionAnswering",
+            "DebertaForSequenceClassification",
+            "DebertaForTokenClassification",
+            "DebertaModel",
+            "DebertaPreTrainedModel",
         ]
     )
     _import_structure["models.deberta_v2"].extend(
         [
             "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "DebertaV2ForSequenceClassification",
-            "DebertaV2Model",
             "DebertaV2ForMaskedLM",
-            "DebertaV2PreTrainedModel",
-            "DebertaV2ForTokenClassification",
             "DebertaV2ForQuestionAnswering",
+            "DebertaV2ForSequenceClassification",
+            "DebertaV2ForTokenClassification",
+            "DebertaV2Model",
+            "DebertaV2PreTrainedModel",
         ]
     )
     _import_structure["models.distilbert"].extend(
@@ -699,7 +673,14 @@ if is_torch_available():
             "LxmertXLayer",
         ]
     )
-    _import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"])
+    _import_structure["models.m2m_100"].extend(
+        [
+            "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "M2M100ForConditionalGeneration",
+            "M2M100Model",
+        ]
+    )
+    _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
     _import_structure["models.mbart"].extend(
         [
             "MBartForCausalLM",
@@ -752,7 +733,7 @@ if is_torch_available():
         ]
     )
     _import_structure["models.pegasus"].extend(
-        ["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"]
+        ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
     )
     _import_structure["models.prophetnet"].extend(
         [
@@ -793,6 +774,13 @@ if is_torch_available():
             "RobertaModel",
         ]
     )
+    _import_structure["models.speech_to_text"].extend(
+        [
+            "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Speech2TextForConditionalGeneration",
+            "Speech2TextModel",
+        ]
+    )
     _import_structure["models.squeezebert"].extend(
         [
             "SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -836,6 +824,15 @@ if is_torch_available():
             "load_tf_weights_in_transfo_xl",
         ]
     )
+    _import_structure["models.wav2vec2"].extend(
+        [
+            "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Wav2Vec2ForCTC",
+            "Wav2Vec2ForMaskedLM",
+            "Wav2Vec2Model",
+            "Wav2Vec2PreTrainedModel",
+        ]
+    )
     _import_structure["models.xlm"].extend(
         [
             "XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -916,20 +913,6 @@ if is_tf_available():
         "shape_list",
     ]
     # TensorFlow models structure
-
-    _import_structure["models.convbert"].extend(
-        [
-            "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-            "TFConvBertForMaskedLM",
-            "TFConvBertForMultipleChoice",
-            "TFConvBertForQuestionAnswering",
-            "TFConvBertForSequenceClassification",
-            "TFConvBertForTokenClassification",
-            "TFConvBertLayer",
-            "TFConvBertModel",
-            "TFConvBertPreTrainedModel",
-        ]
-    )
     _import_structure["models.albert"].extend(
         [
             "TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1002,6 +985,19 @@ if is_tf_available():
             "TFCamembertModel",
         ]
     )
+    _import_structure["models.convbert"].extend(
+        [
+            "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFConvBertForMaskedLM",
+            "TFConvBertForMultipleChoice",
+            "TFConvBertForQuestionAnswering",
+            "TFConvBertForSequenceClassification",
+            "TFConvBertForTokenClassification",
+            "TFConvBertLayer",
+            "TFConvBertModel",
+            "TFConvBertPreTrainedModel",
+        ]
+    )
     _import_structure["models.ctrl"].extend(
         [
             "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1108,7 +1104,7 @@ if is_tf_available():
             "TFLxmertVisualFeatureEncoder",
         ]
     )
-    _import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
+    _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
     _import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
     _import_structure["models.mobilebert"].extend(
         [
@@ -2170,7 +2166,7 @@ if TYPE_CHECKING:
         TFLxmertPreTrainedModel,
         TFLxmertVisualFeatureEncoder,
     )
-    from .models.marian import TFMarian, TFMarianMTModel
+    from .models.marian import TFMarianModel, TFMarianMTModel
     from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
     from .models.mobilebert import (
         TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
src/transformers/models/blenderbot/__init__.py

@@ -29,10 +29,10 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_blenderbot"] = [
         "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BlenderbotForCausalLM",
         "BlenderbotForConditionalGeneration",
         "BlenderbotModel",
         "BlenderbotPreTrainedModel",
-        "BlenderbotForCausalLM",
     ]
 
 
src/transformers/models/blenderbot_small/__init__.py

@@ -28,10 +28,10 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_blenderbot_small"] = [
         "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BlenderbotSmallForCausalLM",
         "BlenderbotSmallForConditionalGeneration",
         "BlenderbotSmallModel",
         "BlenderbotSmallPreTrainedModel",
-        "BlenderbotSmallForCausalLM",
     ]
 
 if is_tf_available():
src/transformers/models/deberta/__init__.py

@@ -29,12 +29,12 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_deberta"] = [
         "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "DebertaForSequenceClassification",
-        "DebertaModel",
         "DebertaForMaskedLM",
-        "DebertaPreTrainedModel",
-        "DebertaForTokenClassification",
         "DebertaForQuestionAnswering",
+        "DebertaForSequenceClassification",
+        "DebertaForTokenClassification",
+        "DebertaModel",
+        "DebertaPreTrainedModel",
     ]
 
 
src/transformers/models/deberta_v2/__init__.py

@@ -29,12 +29,12 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_deberta_v2"] = [
         "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "DebertaV2ForSequenceClassification",
-        "DebertaV2Model",
         "DebertaV2ForMaskedLM",
-        "DebertaV2PreTrainedModel",
-        "DebertaV2ForTokenClassification",
         "DebertaV2ForQuestionAnswering",
+        "DebertaV2ForSequenceClassification",
+        "DebertaV2ForTokenClassification",
+        "DebertaV2Model",
+        "DebertaV2PreTrainedModel",
     ]
 
 
src/transformers/models/ibert/__init__.py

@@ -28,13 +28,13 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_ibert"] = [
         "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "IBertPreTrainedModel",
         "IBertForMaskedLM",
         "IBertForMultipleChoice",
         "IBertForQuestionAnswering",
         "IBertForSequenceClassification",
         "IBertForTokenClassification",
         "IBertModel",
+        "IBertPreTrainedModel",
     ]
 
 if TYPE_CHECKING:
src/transformers/models/marian/__init__.py

@@ -36,14 +36,14 @@ if is_sentencepiece_available():
 if is_torch_available():
     _import_structure["modeling_marian"] = [
         "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "MarianForCausalLM",
         "MarianModel",
         "MarianMTModel",
         "MarianPreTrainedModel",
-        "MarianForCausalLM",
     ]
 
 if is_tf_available():
-    _import_structure["modeling_tf_marian"] = ["TFMarianMTModel", "TFMarianModel"]
+    _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
 
 
 if TYPE_CHECKING:
src/transformers/models/mbart/__init__.py

@@ -35,8 +35,8 @@ if is_sentencepiece_available():
     _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
 
 if is_tokenizers_available():
-    _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
     _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
+    _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
 
 if is_torch_available():
     _import_structure["modeling_mbart"] = [
src/transformers/models/pegasus/__init__.py

@@ -39,10 +39,10 @@ if is_tokenizers_available():
 if is_torch_available():
     _import_structure["modeling_pegasus"] = [
         "PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "PegasusForCausalLM",
         "PegasusForConditionalGeneration",
         "PegasusModel",
         "PegasusPreTrainedModel",
-        "PegasusForCausalLM",
     ]
 
 if is_tf_available():
src/transformers/models/speech_to_text/__init__.py

@@ -29,9 +29,8 @@ _import_structure = {
 }
 
 if is_sentencepiece_available():
-    _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
     _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
+    _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
 
 if is_torch_available():
     _import_structure["modeling_speech_to_text"] = [
src/transformers/models/wav2vec2/__init__.py

@@ -22,16 +22,16 @@ from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
 
 _import_structure = {
     "configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
-    "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
     "feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
     "processing_wav2vec2": ["Wav2Vec2Processor"],
+    "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
 }
 
 if is_torch_available():
     _import_structure["modeling_wav2vec2"] = [
         "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "Wav2Vec2ForMaskedLM",
         "Wav2Vec2ForCTC",
+        "Wav2Vec2ForMaskedLM",
         "Wav2Vec2Model",
         "Wav2Vec2PreTrainedModel",
     ]
src/transformers/utils/dummy_tf_objects.py

@@ -1050,10 +1050,14 @@ class TFLxmertVisualFeatureEncoder:
         requires_tf(self)
 
 
-class TFMarian:
+class TFMarianModel:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
 
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
 
 class TFMarianMTModel:
     def __init__(self, *args, **kwargs):
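
For context, the dummy classes in this file stand in for the TensorFlow models when TF is not installed; renaming `TFMarian` to `TFMarianModel` keeps the placeholder in sync with the real class name, and the added `from_pretrained` makes the usual entry point raise the same helpful error. A simplified sketch of the helper these placeholders call (an assumption about its behavior, not the library's exact code):

# Simplified stand-in for requires_tf: the real helper raises an error
# explaining that TensorFlow is missing from the environment.
def requires_tf(obj):
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the TensorFlow library, but it was not found in your environment.")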
utils/custom_init_isort.py (new file, 241 lines)

@@ -0,0 +1,241 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re


PATH_TO_TRANSFORMERS = "src/transformers"

# Pattern that looks at the indentation in a line.
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":` and puts `key` in group 0.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches `"key",` and puts `key` in group 0.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")


def get_indent(line):
    """Returns the indent in `line`."""
    search = _re_indent.search(line)
    return "" if search is None else search.groups()[0]


def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
    """
    Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
    `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
    after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
    """
    # Let's split the code into lines and move to start_index.
    index = 0
    lines = code.split("\n")
    if start_prompt is not None:
        while not lines[index].startswith(start_prompt):
            index += 1
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # We split into blocks until we get to the `end_prompt` (or the end of the block).
    current_block = [lines[index]]
    index += 1
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
                current_block.append(lines[index])
                blocks.append("\n".join(current_block))
                if index < len(lines) - 1:
                    current_block = [lines[index + 1]]
                    index += 1
                else:
                    current_block = []
            else:
                blocks.append("\n".join(current_block))
                current_block = [lines[index]]
        else:
            current_block.append(lines[index])
        index += 1

    # Adds current block if it's nonempty.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Add final block after end_prompt if provided.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks


def ignore_underscore(key):
    "Wraps a `key` (that maps an object to a string) to lower case and remove underscores."

    def _inner(x):
        return key(x).lower().replace("_", "")

    return _inner


def sort_objects(objects, key=None):
    "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
    # If no key is provided, we use a noop.
    def noop(x):
        return x

    if key is None:
        key = noop
    # Constants are all uppercase, they go first.
    constants = [obj for obj in objects if key(obj).isupper()]
    # Classes are not all uppercase but start with a capital, they go second.
    classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
    # Functions begin with a lowercase, they go last.
    functions = [obj for obj in objects if not key(obj)[0].isupper()]

    key1 = ignore_underscore(key)
    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)


def sort_objects_in_import(import_statement):
    """
    Return the same `import_statement` but with objects properly sorted.
    """
    # This inner function sorts imports between [ ].
    def _replace(match):
        imports = match.groups()[0]
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # We will have a final empty element if the line finished with a comma.
        if len(keys[-1]) == 0:
            keys = keys[:-1]
        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"

    lines = import_statement.split("\n")
    if len(lines) > 3:
        # Here we have to sort internal imports that are on several lines (one per name):
        # key: [
        #     "object1",
        #     "object2",
        #     ...
        # ]

        # We may have to ignore one or two lines on each side.
        idx = 2 if lines[1].strip() == "[" else 1
        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
    elif len(lines) == 3:
        # Here we have to sort internal imports that are on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        if _re_bracket_content.search(lines[1]) is not None:
            lines[1] = _re_bracket_content.sub(_replace, lines[1])
        else:
            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
            # We will have a final empty element if the line finished with a comma.
            if len(keys[-1]) == 0:
                keys = keys[:-1]
            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
        return "\n".join(lines)
    else:
        # Finally we have to deal with imports fitting on one line.
        import_statement = _re_bracket_content.sub(_replace, import_statement)
        return import_statement


def sort_imports(file, check_only=True):
    """
    Sort `_import_structure` imports in `file`; `check_only` determines if we only check or overwrite.
    """
    with open(file, "r") as f:
        code = f.read()

    if "_import_structure" not in code:
        return

    # Blocks of indent level 0
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
    for block_idx in range(1, len(main_blocks) - 1):
        # Check if the block contains some `_import_structure`s thingy to sort.
        block = main_blocks[block_idx]
        block_lines = block.split("\n")
        if len(block_lines) < 3 or "_import_structure" not in "".join(block_lines[:2]):
            continue

        # Ignore first and last line: they don't contain anything.
        internal_block_code = "\n".join(block_lines[1:-1])
        indent = get_indent(block_lines[1])
        # Split the internal block into blocks of indent level 1.
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
        # We have two categories of import key: list or _import_structure[key].append/extend
        pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
        # Grab the keys, but there is a trap: some lines are empty or just comments.
        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
        # We only sort the lines with a key.
        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
        count = 0
        reordered_blocks = []
        for i in range(len(internal_blocks)):
            if keys[i] is None:
                reordered_blocks.append(internal_blocks[i])
            else:
                block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
                reordered_blocks.append(block)
                count += 1

        # And we put our main block back together with its first and last line.
        main_blocks[block_idx] = "\n".join([block_lines[0]] + reordered_blocks + [block_lines[-1]])

    if code != "\n".join(main_blocks):
        if check_only:
            return True
        else:
            print(f"Overwriting {file}.")
            with open(file, "w") as f:
                f.write("\n".join(main_blocks))


def sort_imports_in_all_inits(check_only=True):
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
            if result:
                # Collect every offending init, not just the last one seen.
                failures.append(os.path.join(root, "__init__.py"))
    if len(failures) > 0:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    args = parser.parse_args()

    sort_imports_in_all_inits(check_only=args.check_only)
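
To see the sorter in action on a single multi-line entry, `sort_objects_in_import` can be exercised directly (assuming the script is importable as a module, e.g. run from the utils directory):

from custom_init_isort import sort_objects_in_import

statement = '''    _import_structure["models.marian"] = [
        "MarianModel",
        "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MarianForCausalLM",
    ]'''
print(sort_objects_in_import(statement))
# Prints the same block with the constant first, then the classes in order:
#     _import_structure["models.marian"] = [
#         "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
#         "MarianForCausalLM",
#         "MarianModel",
#     ]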