From b0d49fd5363429659d9b494d4349fefc8577e788 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Sun, 4 Apr 2021 20:41:34 -0400 Subject: [PATCH] Add a script to check inits are consistent (#11024) --- .circleci/config.yml | 1 + Makefile | 1 + src/transformers/__init__.py | 8 + src/transformers/models/gpt_neo/__init__.py | 6 +- src/transformers/models/mt5/__init__.py | 6 + src/transformers/utils/dummy_pt_objects.py | 29 +++ utils/check_inits.py | 191 ++++++++++++++++++++ 7 files changed, 237 insertions(+), 5 deletions(-) create mode 100644 utils/check_inits.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 56d551a9465..999af392fbb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -405,6 +405,7 @@ jobs: - run: python utils/check_table.py - run: python utils/check_dummies.py - run: python utils/check_repo.py + - run: python utils/check_inits.py check_repository_consistency: working_directory: ~/transformers diff --git a/Makefile b/Makefile index 6a09470050a..8661da61c38 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,7 @@ extra_quality_checks: python utils/check_table.py python utils/check_dummies.py python utils/check_repo.py + python utils/check_inits.py # this target runs checks on all files quality: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f5954696e9b..bfba435588a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1552,6 +1552,7 @@ if TYPE_CHECKING: from .training_args import TrainingArguments from .training_args_seq2seq import Seq2SeqTrainingArguments from .training_args_tf import TFTrainingArguments + from .utils import logging if is_sentencepiece_available(): from .models.albert import AlbertTokenizer @@ -1662,6 +1663,12 @@ if TYPE_CHECKING: TopKLogitsWarper, TopPLogitsWarper, ) + from .generation_stopping_criteria import ( + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteria, + StoppingCriteriaList, + ) from 
.generation_utils import top_k_top_p_filtering from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer from .models.albert import ( @@ -1887,6 +1894,7 @@ if TYPE_CHECKING: IBertForSequenceClassification, IBertForTokenClassification, IBertModel, + IBertPreTrainedModel, ) from .models.layoutlm import ( LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/gpt_neo/__init__.py b/src/transformers/models/gpt_neo/__init__.py index 47365597448..7ce86116d60 100644 --- a/src/transformers/models/gpt_neo/__init__.py +++ b/src/transformers/models/gpt_neo/__init__.py @@ -17,17 +17,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available +from ...file_utils import _BaseLazyModule, is_torch_available _import_structure = { "configuration_gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], - "tokenization_gpt_neo": ["GPTNeoTokenizer"], } -if is_tokenizers_available(): - _import_structure["tokenization_gpt_neo_fast"] = ["GPTNeoTokenizerFast"] - if is_torch_available(): _import_structure["modeling_gpt_neo"] = [ "GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST", diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index c72aa3411a7..b4b44499562 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -41,6 +41,12 @@ _import_structure = { "configuration_mt5": ["MT5Config"], } +if is_sentencepiece_available(): + _import_structure["."] = ["T5Tokenizer"] # Fake to get the same objects in both sides. +if is_tokenizers_available(): + _import_structure["."] = ["T5TokenizerFast"] # Fake to get the same objects in both sides. 
+ if is_torch_available(): _import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 59649a3c02b..942d267cfad 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -198,6 +198,26 @@ class TopPLogitsWarper: requires_pytorch(self) +class MaxLengthCriteria: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class MaxTimeCriteria: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class StoppingCriteria: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class StoppingCriteriaList: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + def top_k_top_p_filtering(*args, **kwargs): requires_pytorch(top_k_top_p_filtering) @@ -1539,6 +1559,15 @@ class IBertModel: requires_pytorch(self) +class IBertPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/utils/check_inits.py b/utils/check_inits.py new file mode 100644 index 00000000000..7d024ed3951 --- /dev/null +++ b/utils/check_inits.py @@ -0,0 +1,191 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import re + + +PATH_TO_TRANSFORMERS = "src/transformers" +BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"] + +# Catches a line with a key-values pattern: "bla": ["foo", "bar"] +_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]') +# Catches a line if is_foo_available +_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$") +# Catches a line _import_struct["bla"].append("foo") +_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)') +# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"] +_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]") +# Catches a line with an object between quotes and a comma: "MyModel", +_re_quote_object = re.compile('^\s+"([^"]+)",') +# Catches a line with objects between brackets only: ["foo", "bar"], +_re_between_brackets = re.compile("^\s+\[([^\]]+)\]") +# Catches a line with from foo import bar, bla, boo +_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") + + +def parse_init(init_file): + """ + Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects + defined + """ + with open(init_file, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + line_index = 0 + while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"): + line_index += 1 + + # If this is a traditional init, just return. 
+ if line_index >= len(lines): + return None + + # First grab the objects without a specific backend in _import_structure + objects = [] + while not lines[line_index].startswith("if TYPE_CHECKING") and _re_test_backend.search(lines[line_index]) is None: + line = lines[line_index] + single_line_import_search = _re_import_struct_key_value.search(line) + if single_line_import_search is not None: + imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0] + objects.extend(imports) + elif line.startswith(" " * 8 + '"'): + objects.append(line[9:-3]) + line_index += 1 + + import_dict_objects = {"none": objects} + # Let's continue with backend-specific objects in _import_structure + while not lines[line_index].startswith("if TYPE_CHECKING"): + # If the line is an if is_backend_available, we grab all objects associated. + if _re_test_backend.search(lines[line_index]) is not None: + backend = _re_test_backend.search(lines[line_index]).groups()[0] + line_index += 1 + + # Ignore if backend isn't tracked for dummies. 
+ if backend not in BACKENDS: + continue + + objects = [] + # Until we unindent, add backend objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4): + line = lines[line_index] + if _re_import_struct_add_one.search(line) is not None: + objects.append(_re_import_struct_add_one.search(line).groups()[0]) + elif _re_import_struct_add_many.search(line) is not None: + imports = _re_import_struct_add_many.search(line).groups()[0].split(", ") + imports = [obj[1:-1] for obj in imports if len(obj) > 0] + objects.extend(imports) + elif _re_between_brackets.search(line) is not None: + imports = _re_between_brackets.search(line).groups()[0].split(", ") + imports = [obj[1:-1] for obj in imports if len(obj) > 0] + objects.extend(imports) + elif _re_quote_object.search(line) is not None: + objects.append(_re_quote_object.search(line).groups()[0]) + elif line.startswith(" " * 8 + '"'): + objects.append(line[9:-3]) + elif line.startswith(" " * 12 + '"'): + objects.append(line[13:-3]) + line_index += 1 + + import_dict_objects[backend] = objects + else: + line_index += 1 + + # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend + objects = [] + while ( + line_index < len(lines) + and _re_test_backend.search(lines[line_index]) is None + and not lines[line_index].startswith("else") + ): + line = lines[line_index] + single_line_import_search = _re_import.search(line) + if single_line_import_search is not None: + objects.extend(single_line_import_search.groups()[0].split(", ")) + elif line.startswith(" " * 8): + objects.append(line[8:-2]) + line_index += 1 + + type_hint_objects = {"none": objects} + # Let's continue with backend-specific objects + while line_index < len(lines): + # If the line is an if is_backend_available, we grab all objects associated. 
+ if _re_test_backend.search(lines[line_index]) is not None: + backend = _re_test_backend.search(lines[line_index]).groups()[0] + line_index += 1 + + # Ignore if backend isn't tracked for dummies. + if backend not in BACKENDS: + continue + + objects = [] + # Until we unindent, add backend objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): + line = lines[line_index] + single_line_import_search = _re_import.search(line) + if single_line_import_search is not None: + objects.extend(single_line_import_search.groups()[0].split(", ")) + elif line.startswith(" " * 12): + objects.append(line[12:-2]) + line_index += 1 + + type_hint_objects[backend] = objects + else: + line_index += 1 + + return import_dict_objects, type_hint_objects + + +def analyze_results(import_dict_objects, type_hint_objects): + """ + Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init. + """ + if list(import_dict_objects.keys()) != list(type_hint_objects.keys()): + return ["Both sides of the init do not have the same backends!"] + + errors = [] + for key in import_dict_objects.keys(): + if sorted(import_dict_objects[key]) != sorted(type_hint_objects[key]): + name = "base imports" if key == "none" else f"{key} backend" + errors.append(f"Differences for {name}:") + for a in type_hint_objects[key]: + if a not in import_dict_objects[key]: + errors.append(f" {a} in TYPE_HINT but not in _import_structure.") + for a in import_dict_objects[key]: + if a not in type_hint_objects[key]: + errors.append(f" {a} in _import_structure but not in TYPE_HINT.") + return errors + + +def check_all_inits(): + """ + Check all inits in the transformers repo and raise an error if at least one does not define the same objects in + both halves. 
+ """ + failures = [] + for root, _, files in os.walk(PATH_TO_TRANSFORMERS): + if "__init__.py" in files: + fname = os.path.join(root, "__init__.py") + objects = parse_init(fname) + if objects is not None: + errors = analyze_results(*objects) + if len(errors) > 0: + errors[0] = f"Problem in {fname}, both halves do not define the same objects.\n{errors[0]}" + failures.append("\n".join(errors)) + if len(failures) > 0: + raise ValueError("\n\n".join(failures)) + + +if __name__ == "__main__": + check_all_inits()