Sort init import (#10801)

* Initial script

* Add script to properly sort imports in init.

* Add to the CI

* Update utils/custom_init_isort.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Separate scripts that change content from quality

* Move class_mapping_update to style_checks

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger 2021-03-19 16:17:13 -04:00 committed by GitHub
parent 1438c487df
commit 21e86f99e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 359 additions and 114 deletions

View File

@ -383,6 +383,7 @@ jobs:
- '~/.cache/pip'
- run: black --check examples tests src utils
- run: isort --check-only examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: flake8 examples tests src utils
- run: python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
- run: python utils/check_copies.py

View File

@ -21,32 +21,36 @@ deps_table_update:
# Check that source code meets quality standards
extra_quality_checks: deps_table_update
extra_quality_checks:
python utils/check_copies.py
python utils/check_table.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/style_doc.py src/transformers docs/source --max_len 119
python utils/class_mapping_update.py
# this target runs checks on all files
quality:
black --check $(check_dirs)
isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only
flake8 $(check_dirs)
python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
${MAKE} extra_quality_checks
# Format source code automatically and check is there are any problems left that need manual fixing
style: deps_table_update
extra_style_checks: deps_table_update
python utils/custom_init_isort.py
python utils/style_doc.py src/transformers docs/source --max_len 119
python utils/class_mapping_update.py
# this target runs checks on all files
style:
black $(check_dirs)
isort $(check_dirs)
python utils/style_doc.py src/transformers docs/source --max_len 119
${MAKE} extra_style_checks
# Super fast fix and check target that only works on relevant modified files since the branch was made
fixup: modified_only_fixup extra_quality_checks
fixup: modified_only_fixup extra_style_checks extra_quality_checks
# Make marked copies of snippets of codes conform to the original

View File

@ -78,6 +78,7 @@ _import_structure = {
"xnli_processors",
"xnli_tasks_num_labels",
],
"feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
"file_utils": [
"CONFIG_NAME",
"MODEL_CARD_NAME",
@ -124,23 +125,8 @@ _import_structure = {
"load_tf2_model_in_pytorch_model",
"load_tf2_weights_in_pytorch_model",
],
"models": [],
# Models
"models.wav2vec2": [
"WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2Config",
"Wav2Vec2CTCTokenizer",
"Wav2Vec2Tokenizer",
"Wav2Vec2FeatureExtractor",
"Wav2Vec2Processor",
],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
"Speech2TextFeatureExtractor",
],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models": [],
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
"models.auto": [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
@ -169,6 +155,7 @@ _import_structure = {
"BlenderbotSmallTokenizer",
],
"models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
"models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"],
"models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
@ -193,6 +180,7 @@ _import_structure = {
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
"models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.marian": ["MarianConfig"],
"models.mbart": ["MBartConfig"],
"models.mmbt": ["MMBTConfig"],
@ -207,6 +195,11 @@ _import_structure = {
"models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
"models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
"Speech2TextFeatureExtractor",
],
"models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
"models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
@ -216,6 +209,14 @@ _import_structure = {
"TransfoXLCorpus",
"TransfoXLTokenizer",
],
"models.wav2vec2": [
"WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2Config",
"Wav2Vec2CTCTokenizer",
"Wav2Vec2FeatureExtractor",
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@ -251,7 +252,6 @@ _import_structure = {
"SpecialTokensMixin",
"TokenSpan",
],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
"trainer_callback": [
"DefaultFlowCallback",
"EarlyStoppingCallback",
@ -383,54 +383,14 @@ if is_torch_available():
"TopPLogitsWarper",
]
_import_structure["generation_stopping_criteria"] = [
"StoppingCriteria",
"StoppingCriteriaList",
"MaxLengthCriteria",
"MaxTimeCriteria",
"StoppingCriteria",
"StoppingCriteriaList",
]
_import_structure["generation_utils"] = ["top_k_top_p_filtering"]
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure
_import_structure["models.speech_to_text"].extend(
[
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
]
)
_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
)
_import_structure["models.m2m_100"].extend(
[
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
"M2M100Model",
]
)
_import_structure["models.convbert"].extend(
[
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConvBertForMaskedLM",
"ConvBertForMultipleChoice",
"ConvBertForQuestionAnswering",
"ConvBertForSequenceClassification",
"ConvBertForTokenClassification",
"ConvBertLayer",
"ConvBertModel",
"ConvBertPreTrainedModel",
"load_tf_weights_in_convbert",
]
)
_import_structure["models.albert"].extend(
[
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -512,17 +472,17 @@ if is_torch_available():
_import_structure["models.blenderbot"].extend(
[
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotForCausalLM",
]
)
_import_structure["models.blenderbot_small"].extend(
[
"BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
"BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallForCausalLM",
]
)
_import_structure["models.camembert"].extend(
@ -537,6 +497,20 @@ if is_torch_available():
"CamembertModel",
]
)
_import_structure["models.convbert"].extend(
[
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConvBertForMaskedLM",
"ConvBertForMultipleChoice",
"ConvBertForQuestionAnswering",
"ConvBertForSequenceClassification",
"ConvBertForTokenClassification",
"ConvBertLayer",
"ConvBertModel",
"ConvBertPreTrainedModel",
"load_tf_weights_in_convbert",
]
)
_import_structure["models.ctrl"].extend(
[
"CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -549,23 +523,23 @@ if is_torch_available():
_import_structure["models.deberta"].extend(
[
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForSequenceClassification",
"DebertaModel",
"DebertaForMaskedLM",
"DebertaPreTrainedModel",
"DebertaForTokenClassification",
"DebertaForQuestionAnswering",
"DebertaForSequenceClassification",
"DebertaForTokenClassification",
"DebertaModel",
"DebertaPreTrainedModel",
]
)
_import_structure["models.deberta_v2"].extend(
[
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaV2ForSequenceClassification",
"DebertaV2Model",
"DebertaV2ForMaskedLM",
"DebertaV2PreTrainedModel",
"DebertaV2ForTokenClassification",
"DebertaV2ForQuestionAnswering",
"DebertaV2ForSequenceClassification",
"DebertaV2ForTokenClassification",
"DebertaV2Model",
"DebertaV2PreTrainedModel",
]
)
_import_structure["models.distilbert"].extend(
@ -699,7 +673,14 @@ if is_torch_available():
"LxmertXLayer",
]
)
_import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"])
_import_structure["models.m2m_100"].extend(
[
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
"M2M100Model",
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
_import_structure["models.mbart"].extend(
[
"MBartForCausalLM",
@ -752,7 +733,7 @@ if is_torch_available():
]
)
_import_structure["models.pegasus"].extend(
["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"]
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
)
_import_structure["models.prophetnet"].extend(
[
@ -793,6 +774,13 @@ if is_torch_available():
"RobertaModel",
]
)
_import_structure["models.speech_to_text"].extend(
[
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
]
)
_import_structure["models.squeezebert"].extend(
[
"SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -836,6 +824,15 @@ if is_torch_available():
"load_tf_weights_in_transfo_xl",
]
)
_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
)
_import_structure["models.xlm"].extend(
[
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -916,20 +913,6 @@ if is_tf_available():
"shape_list",
]
# TensorFlow models structure
_import_structure["models.convbert"].extend(
[
"TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFConvBertForMaskedLM",
"TFConvBertForMultipleChoice",
"TFConvBertForQuestionAnswering",
"TFConvBertForSequenceClassification",
"TFConvBertForTokenClassification",
"TFConvBertLayer",
"TFConvBertModel",
"TFConvBertPreTrainedModel",
]
)
_import_structure["models.albert"].extend(
[
"TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -1002,6 +985,19 @@ if is_tf_available():
"TFCamembertModel",
]
)
_import_structure["models.convbert"].extend(
[
"TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFConvBertForMaskedLM",
"TFConvBertForMultipleChoice",
"TFConvBertForQuestionAnswering",
"TFConvBertForSequenceClassification",
"TFConvBertForTokenClassification",
"TFConvBertLayer",
"TFConvBertModel",
"TFConvBertPreTrainedModel",
]
)
_import_structure["models.ctrl"].extend(
[
"TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -1108,7 +1104,7 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder",
]
)
_import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.mobilebert"].extend(
[
@ -2170,7 +2166,7 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import TFMarian, TFMarianMTModel
from .models.marian import TFMarianModel, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -29,10 +29,10 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_blenderbot"] = [
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotPreTrainedModel",
"BlenderbotForCausalLM",
]

View File

@ -28,10 +28,10 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_blenderbot_small"] = [
"BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
"BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
"BlenderbotSmallForCausalLM",
]
if is_tf_available():

View File

@ -29,12 +29,12 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_deberta"] = [
"DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaForSequenceClassification",
"DebertaModel",
"DebertaForMaskedLM",
"DebertaPreTrainedModel",
"DebertaForTokenClassification",
"DebertaForQuestionAnswering",
"DebertaForSequenceClassification",
"DebertaForTokenClassification",
"DebertaModel",
"DebertaPreTrainedModel",
]

View File

@ -29,12 +29,12 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_deberta_v2"] = [
"DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST",
"DebertaV2ForSequenceClassification",
"DebertaV2Model",
"DebertaV2ForMaskedLM",
"DebertaV2PreTrainedModel",
"DebertaV2ForTokenClassification",
"DebertaV2ForQuestionAnswering",
"DebertaV2ForSequenceClassification",
"DebertaV2ForTokenClassification",
"DebertaV2Model",
"DebertaV2PreTrainedModel",
]

View File

@ -28,13 +28,13 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_ibert"] = [
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"IBertPreTrainedModel",
"IBertForMaskedLM",
"IBertForMultipleChoice",
"IBertForQuestionAnswering",
"IBertForSequenceClassification",
"IBertForTokenClassification",
"IBertModel",
"IBertPreTrainedModel",
]
if TYPE_CHECKING:

View File

@ -36,14 +36,14 @@ if is_sentencepiece_available():
if is_torch_available():
_import_structure["modeling_marian"] = [
"MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST",
"MarianForCausalLM",
"MarianModel",
"MarianMTModel",
"MarianPreTrainedModel",
"MarianForCausalLM",
]
if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianMTModel", "TFMarianModel"]
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
if TYPE_CHECKING:

View File

@ -35,8 +35,8 @@ if is_sentencepiece_available():
_import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"]
if is_tokenizers_available():
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
_import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"]
_import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"]
if is_torch_available():
_import_structure["modeling_mbart"] = [

View File

@ -39,10 +39,10 @@ if is_tokenizers_available():
if is_torch_available():
_import_structure["modeling_pegasus"] = [
"PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST",
"PegasusForCausalLM",
"PegasusForConditionalGeneration",
"PegasusModel",
"PegasusPreTrainedModel",
"PegasusForCausalLM",
]
if is_tf_available():

View File

@ -29,9 +29,8 @@ _import_structure = {
}
if is_sentencepiece_available():
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
if is_torch_available():
_import_structure["modeling_speech_to_text"] = [

View File

@ -22,16 +22,16 @@ from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_ava
_import_structure = {
"configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
"feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"],
"processing_wav2vec2": ["Wav2Vec2Processor"],
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
}
if is_torch_available():
_import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2ForCTC",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]

View File

@ -1050,10 +1050,14 @@ class TFLxmertVisualFeatureEncoder:
requires_tf(self)
class TFMarian:
class TFMarianModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFMarianMTModel:
def __init__(self, *args, **kwargs):

241
utils/custom_init_isort.py Normal file
View File

@ -0,0 +1,241 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
PATH_TO_TRANSFORMERS = "src/transformers"
# Pattern that looks at the indentation in a line.
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":" and puts `key` in group 0.
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches `"key",` and puts `key` in group 0.
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
def get_indent(line):
"""Returns the indent in `line`."""
search = _re_indent.search(line)
return "" if search is None else search.groups()[0]
def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
"""
Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
`start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
"""
# Let's split the code into lines and move to start_index.
index = 0
lines = code.split("\n")
if start_prompt is not None:
while not lines[index].startswith(start_prompt):
index += 1
blocks = ["\n".join(lines[:index])]
else:
blocks = []
# We split into blocks until we get to the `end_prompt` (or the end of the block).
current_block = [lines[index]]
index += 1
while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
current_block.append(lines[index])
blocks.append("\n".join(current_block))
if index < len(lines) - 1:
current_block = [lines[index + 1]]
index += 1
else:
current_block = []
else:
blocks.append("\n".join(current_block))
current_block = [lines[index]]
else:
current_block.append(lines[index])
index += 1
# Adds current block if it's nonempty.
if len(current_block) > 0:
blocks.append("\n".join(current_block))
# Add final block after end_prompt if provided.
if end_prompt is not None and index < len(lines):
blocks.append("\n".join(lines[index:]))
return blocks
def ignore_underscore(key):
"Wraps a `key` (that maps an object to string) to lower case and remove underscores."
def _inner(x):
return key(x).lower().replace("_", "")
return _inner
def sort_objects(objects, key=None):
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
# If no key is provided, we use a noop.
def noop(x):
return x
if key is None:
key = noop
# Constants are all uppercase, they go first.
constants = [obj for obj in objects if key(obj).isupper()]
# Classes are not all uppercase but start with a capital, they go second.
classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
# Functions begin with a lowercase, they go last.
functions = [obj for obj in objects if not key(obj)[0].isupper()]
key1 = ignore_underscore(key)
return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
def sort_objects_in_import(import_statement):
"""
Return the same `import_statement` but with objects properly sorted.
"""
# This inner function sort imports between [ ].
def _replace(match):
imports = match.groups()[0]
if "," not in imports:
return f"[{imports}]"
keys = [part.strip().replace('"', "") for part in imports.split(",")]
# We will have a final empty element if the line finished with a comma.
if len(keys[-1]) == 0:
keys = keys[:-1]
return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
lines = import_statement.split("\n")
if len(lines) > 3:
# Here we have to sort internal imports that are on several lines (one per name):
# key: [
# "object1",
# "object2",
# ...
# ]
# We may have to ignore one or two lines on each side.
idx = 2 if lines[1].strip() == "[" else 1
keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
elif len(lines) == 3:
# Here we have to sort internal imports that are on one separate line:
# key: [
# "object1", "object2", ...
# ]
if _re_bracket_content.search(lines[1]) is not None:
lines[1] = _re_bracket_content.sub(_replace, lines[1])
else:
keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
# We will have a final empty element if the line finished with a comma.
if len(keys[-1]) == 0:
keys = keys[:-1]
lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
return "\n".join(lines)
else:
# Finally we have to deal with imports fitting on one line
import_statement = _re_bracket_content.sub(_replace, import_statement)
return import_statement
def sort_imports(file, check_only=True):
"""
Sort `_import_structure` imports in `file`, `check_only` determines if we only check or overwrite.
"""
with open(file, "r") as f:
code = f.read()
if "_import_structure" not in code:
return
# Blocks of indent level 0
main_blocks = split_code_in_indented_blocks(
code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
)
# We ignore block 0 (everything untils start_prompt) and the last block (everything after end_prompt).
for block_idx in range(1, len(main_blocks) - 1):
# Check if the block contains some `_import_structure`s thingy to sort.
block = main_blocks[block_idx]
block_lines = block.split("\n")
if len(block_lines) < 3 or "_import_structure" not in "".join(block_lines[:2]):
continue
# Ignore first and last line: they don't contain anything.
internal_block_code = "\n".join(block_lines[1:-1])
indent = get_indent(block_lines[1])
# Slit the internal block into blocks of indent level 1.
internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
# We have two categories of import key: list or _import_structu[key].append/extend
pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
# Grab the keys, but there is a trap: some lines are empty or jsut comments.
keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
# We only sort the lines with a key.
keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
# We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
count = 0
reorderded_blocks = []
for i in range(len(internal_blocks)):
if keys[i] is None:
reorderded_blocks.append(internal_blocks[i])
else:
block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
reorderded_blocks.append(block)
count += 1
# And we put our main block back together with its first and last line.
main_blocks[block_idx] = "\n".join([block_lines[0]] + reorderded_blocks + [block_lines[-1]])
if code != "\n".join(main_blocks):
if check_only:
return True
else:
print(f"Overwriting {file}.")
with open(file, "w") as f:
f.write("\n".join(main_blocks))
def sort_imports_in_all_inits(check_only=True):
failures = []
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
if "__init__.py" in files:
result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
if result:
failures = [os.path.join(root, "__init__.py")]
if len(failures) > 0:
raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
args = parser.parse_args()
sort_imports_in_all_inits(check_only=args.check_only)