mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

Add a script to check inits are consistent (#11024)

This commit is contained in:
parent 335c0ca35c
commit b0d49fd536
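
For context: the transformers __init__ files use a lazy-import pattern split into two halves that must list exactly the same objects, which is what the new utils/check_inits.py verifies. A minimal sketch of the pattern (toy submodule and class names, for illustration only):

from typing import TYPE_CHECKING

from .file_utils import _BaseLazyModule, is_torch_available

# Runtime half: maps each submodule to the names it exports, imported lazily.
_import_structure = {
    "configuration_toy": ["ToyConfig"],
}

if is_torch_available():
    _import_structure["modeling_toy"] = ["ToyModel"]

# Static half: mirrors the same objects as real imports, for type checkers and IDEs.
if TYPE_CHECKING:
    from .configuration_toy import ToyConfig

    if is_torch_available():
        from .modeling_toy import ToyModel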
.circleci/config.yml
@@ -405,6 +405,7 @@ jobs:
       - run: python utils/check_table.py
       - run: python utils/check_dummies.py
       - run: python utils/check_repo.py
+      - run: python utils/check_inits.py

   check_repository_consistency:
     working_directory: ~/transformers
Makefile
@@ -31,6 +31,7 @@ extra_quality_checks:
	python utils/check_table.py
	python utils/check_dummies.py
	python utils/check_repo.py
+	python utils/check_inits.py

 # this target runs checks on all files
 quality:
src/transformers/__init__.py
@@ -1552,6 +1552,7 @@ if TYPE_CHECKING:
     from .training_args import TrainingArguments
     from .training_args_seq2seq import Seq2SeqTrainingArguments
     from .training_args_tf import TFTrainingArguments
+    from .utils import logging

     if is_sentencepiece_available():
         from .models.albert import AlbertTokenizer
@@ -1662,6 +1663,12 @@ if TYPE_CHECKING:
         TopKLogitsWarper,
         TopPLogitsWarper,
     )
+    from .generation_stopping_criteria import (
+        MaxLengthCriteria,
+        MaxTimeCriteria,
+        StoppingCriteria,
+        StoppingCriteriaList,
+    )
     from .generation_utils import top_k_top_p_filtering
     from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer
     from .models.albert import (
@@ -1887,6 +1894,7 @@ if TYPE_CHECKING:
         IBertForSequenceClassification,
         IBertForTokenClassification,
         IBertModel,
+        IBertPreTrainedModel,
     )
     from .models.layoutlm import (
         LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
src/transformers/models/gpt_neo/__init__.py
@@ -17,17 +17,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
+from ...file_utils import _BaseLazyModule, is_torch_available


 _import_structure = {
     "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
-    "tokenization_gpt_neo": ["GPTNeoTokenizer"],
 }

-if is_tokenizers_available():
-    _import_structure["tokenization_gpt_neo_fast"] = ["GPTNeoTokenizerFast"]
-
 if is_torch_available():
     _import_structure["modeling_gpt_neo"] = [
         "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
src/transformers/models/mt5/__init__.py
@@ -41,6 +41,12 @@ _import_structure = {
     "configuration_mt5": ["MT5Config"],
 }

+if is_sentencepiece_available():
+    _import_structure["."] = ["T5Tokenizer"]  # Fake to get the same objects in both side.
+
+if is_tokenizers_available():
+    _import_structure["."] = ["T5TokenizerFast"]  # Fake to get the same objects in both side.
+
 if is_torch_available():
     _import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]

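mT5 reuses the T5 tokenizer classes rather than defining its own, so there is no tokenization_mt5 submodule for _import_structure to point at; the "." key is a placeholder whose only job is to make this half list the same objects that the TYPE_CHECKING half imports, so the new checker sees both halves agree. A sketch of what the matching TYPE_CHECKING half looks like (the exact import paths are an assumption, they are not shown in this diff):

if TYPE_CHECKING:
    from .configuration_mt5 import MT5Config

    if is_sentencepiece_available():
        from ..t5.tokenization_t5 import T5Tokenizer

    if is_tokenizers_available():
        from ..t5.tokenization_t5_fast import T5TokenizerFast

    if is_torch_available():
        from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model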
src/transformers/utils/dummy_pt_objects.py
@@ -198,6 +198,26 @@ class TopPLogitsWarper:
         requires_pytorch(self)


+class MaxLengthCriteria:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class MaxTimeCriteria:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class StoppingCriteria:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class StoppingCriteriaList:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 def top_k_top_p_filtering(*args, **kwargs):
     requires_pytorch(top_k_top_p_filtering)

@@ -1539,6 +1559,15 @@ class IBertModel:
         requires_pytorch(self)


+class IBertPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None

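These dummy definitions exist so that, when PyTorch is not installed, using one of these names fails with a helpful message instead of an opaque ImportError; check_dummies.py (run just before check_inits.py above) keeps them in sync with the real objects. A minimal sketch of the requires_pytorch helper the dummies call — its real definition lives elsewhere in the library, so treat this as an assumption based on how it is used here:

def requires_pytorch(obj):
    # Resolve a readable name whether obj is an instance, a class, or a function.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    # is_torch_available is assumed to come from transformers.file_utils.
    if not is_torch_available():
        raise ImportError(f"{name} requires the PyTorch library, but it was not found in your environment.")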
utils/check_inits.py (new file)
@@ -0,0 +1,191 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re


PATH_TO_TRANSFORMERS = "src/transformers"
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]

# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
_re_quote_object = re.compile(r'^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")


def parse_init(init_file):
    """
    Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
    defined.
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # If this is a traditional init, just return.
    if line_index >= len(lines):
        return None

    # First grab the objects without a specific backend in _import_structure
    objects = []
    while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None:
        line = lines[line_index]
        single_line_import_search = _re_import_struct_key_value.search(line)
        if single_line_import_search is not None:
            imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        elif line.startswith(" " * 8 + '"'):
            objects.append(line[9:-3])
        line_index += 1

    import_dict_objects = {"none": objects}
    # Let's continue with backend-specific objects in _import_structure
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        # If the line is an if is_backend_available, we grab all objects associated.
        if _re_test_backend.search(lines[line_index]) is not None:
            backend = _re_test_backend.search(lines[line_index]).groups()[0]
            line_index += 1

            # Ignore if backend isn't tracked for dummies.
            if backend not in BACKENDS:
                continue

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
                line = lines[line_index]
                if _re_import_struct_add_one.search(line) is not None:
                    objects.append(_re_import_struct_add_one.search(line).groups()[0])
                elif _re_import_struct_add_many.search(line) is not None:
                    imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_between_brackets.search(line) is not None:
                    imports = _re_between_brackets.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_quote_object.search(line) is not None:
                    objects.append(_re_quote_object.search(line).groups()[0])
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
    objects = []
    while (
        line_index < len(lines)
        and _re_test_backend.search(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        single_line_import_search = _re_import.search(line)
        if single_line_import_search is not None:
            objects.extend(single_line_import_search.groups()[0].split(", "))
        elif line.startswith(" " * 8):
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}
    # Let's continue with backend-specific objects
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        if _re_test_backend.search(lines[line_index]) is not None:
            backend = _re_test_backend.search(lines[line_index]).groups()[0]
            line_index += 1

            # Ignore if backend isn't tracked for dummies.
            if backend not in BACKENDS:
                continue

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects


def analyze_results(import_dict_objects, type_hint_objects):
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
    """
    if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for key in import_dict_objects.keys():
        if sorted(import_dict_objects[key]) != sorted(type_hint_objects[key]):
            name = "base imports" if key == "none" else f"{key} backend"
            errors.append(f"Differences for {name}:")
            for a in type_hint_objects[key]:
                if a not in import_dict_objects[key]:
                    errors.append(f"  {a} in TYPE_HINT but not in _import_structure.")
            for a in import_dict_objects[key]:
                if a not in type_hint_objects[key]:
                    errors.append(f"  {a} in _import_structure but not in TYPE_HINT.")
    return errors


def check_all_inits():
    """
    Check all inits in the transformers repo and raise an error if at least one does not define the same objects in
    both halves.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            fname = os.path.join(root, "__init__.py")
            objects = parse_init(fname)
            if objects is not None:
                errors = analyze_results(*objects)
                if len(errors) > 0:
                    errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}"
                    failures.append("\n".join(errors))
    if len(failures) > 0:
        raise ValueError("\n\n".join(failures))


if __name__ == "__main__":
    check_all_inits()
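
The script takes no arguments and is meant to run from the repository root, as wired into CI and the Makefile above (python utils/check_inits.py). For ad-hoc debugging of a single init while editing it, one way to reuse its helpers — a sketch, assuming you run from the repo root:

import runpy

# Execute the script's module body without triggering check_all_inits()
# (the __main__ guard does not fire under runpy's default run name).
helpers = runpy.run_path("utils/check_inits.py")

objects = helpers["parse_init"]("src/transformers/models/mt5/__init__.py")
if objects is not None:
    # Prints nothing if the two halves of the init agree.
    print("\n".join(helpers["analyze_results"](*objects)))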