transformers/utils/check_dummies.py
Sylvain Gugger 758ed3332b
Transformers fast import part 2 (#9446)
* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one models.

* Fixes for import

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
2021-01-07 09:36:14 -05:00

185 lines
6.4 KiB
Python

# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
# All paths are set with the intent that you run this script from the root of the repo with the command:
# python utils/check_dummies.py
# Location of the package whose main `__init__.py` is parsed (relative to the repo root).
PATH_TO_TRANSFORMERS = "src/transformers"

# Matches an indented single-line `from xxx import a, b` statement and captures the imported names.
# Imports opening a parenthesis (multi-line imports) are deliberately not matched.
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Matches an indented `if is_xxx_available():` line and captures the backend name `xxx`.
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")

# Backends for which dummy objects are generated; other `is_xxx_available()` guards are ignored.
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]

# Source templates for the generated dummy objects. In every template `{0}` is the object name and
# `{1}` is the suffix of the `requires_xxx` helper called by the dummy. The bodies are indented so
# that the rendered text is valid Python once written to a dummy module.
DUMMY_CONSTANT = """
{0} = None
"""

DUMMY_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_{1}(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_{1}(self)
"""

DUMMY_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_{1}(self)
"""

DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_{1}({0})
"""
def read_init():
    """
    Read the main init and extract the objects imported under each tracked backend guard.

    The file `__init__.py` located in `PATH_TO_TRANSFORMERS` is scanned from the first line starting with
    `if TYPE_CHECKING` to the end of the file. Every line matching `_re_test_backend` whose backend is listed
    in `BACKENDS` opens a section; the names imported in the following lines indented by at least 8 spaces
    (blank lines are skipped over) are collected for that backend.

    Returns:
        dict: Maps each backend name found to the list of object names imported under its guard, in file order.

    Raises:
        ValueError: If the init contains no line starting with `if TYPE_CHECKING`.
    """
    init_path = os.path.join(PATH_TO_TRANSFORMERS, "__init__.py")
    with open(init_path, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    num_lines = len(lines)

    # Get to the point we do the actual imports for type checking
    line_index = 0
    while line_index < num_lines and not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1
    if line_index >= num_lines:
        raise ValueError(f"No line starting with `if TYPE_CHECKING` was found in {init_path}.")

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < num_lines:
        # If the line is an `if is_backend_available()`, we grab all objects associated.
        backend_match = _re_test_backend.search(lines[line_index])
        line_index += 1
        if backend_match is None:
            continue
        backend = backend_match.groups()[0]

        # Ignore if backend isn't tracked for dummies.
        if backend not in BACKENDS:
            continue

        objects = []
        # Until we unindent (or reach the end of the file), add backend objects to the list
        while line_index < num_lines and (
            len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)
        ):
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                # `from x import a, b`: every comma-separated name is an object.
                objects.extend(single_line_import_search.groups()[0].split(", "))
            elif line.startswith(" " * 12):
                # One name per line inside a parenthesized import: drop the 12-space indent and the
                # trailing ",\n".
                objects.append(line[12:-2])
            line_index += 1

        backend_specific_objects[backend] = objects

    return backend_specific_objects
def create_dummy_object(name, backend_name):
    """
    Create the code for the dummy object corresponding to `name`.

    The template is picked from the casing of `name`: an all-uppercase name is rendered as a constant, an
    all-lowercase name as a function, and anything else as a class. A class whose name contains one of the
    fragments listed in `_pretrained` also gets a `from_pretrained` classmethod.

    Args:
        name (str): Name of the object to create a dummy for.
        backend_name (str): Suffix of the `requires_xxx` helper the dummy calls.

    Returns:
        str: The source code of the dummy object.
    """
    # Name fragments identifying classes rendered with `DUMMY_PRETRAINED_CLASS`. Each entry needs its own
    # comma: adjacent string literals are silently concatenated into a single entry.
    _pretrained = [
        "Config",
        "ForCausalLM",
        "ForConditionalGeneration",
        "ForMaskedLM",
        "ForMultipleChoice",
        "ForQuestionAnswering",
        "ForSequenceClassification",
        "ForTokenClassification",
        "Model",
        "Tokenizer",
    ]
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    if name.islower():
        return DUMMY_FUNCTION.format(name, backend_name)
    if any(part in name for part in _pretrained):
        return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
    return DUMMY_CLASS.format(name, backend_name)
def create_dummy_files():
    """
    Build the expected content of every dummy file.

    Returns:
        dict: Maps each backend returned by `read_init` to the full text of its dummy module: a fixed
        header, the import of the matching `requires_xxx` helper, then one dummy per object.
    """
    # The `requires_xxx` helper does not always use the backend name as its suffix.
    requires_suffixes = {"torch": "pytorch"}
    header = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"

    generated = {}
    for backend, names in read_init().items():
        suffix = requires_suffixes.get(backend, backend)
        import_line = f"from ..file_utils import requires_{suffix}\n\n"
        body = "\n".join(create_dummy_object(obj_name, suffix) for obj_name in names)
        generated[backend] = header + import_line + body
    return generated
def check_dummies(overwrite=False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    The expected content comes from `create_dummy_files` and is compared with the files
    `dummy_xxx_objects.py` in the `utils` folder of `PATH_TO_TRANSFORMERS`. A dummy file that does not exist
    yet is treated as empty, so it is reported as out of date and is created when `overwrite` is set.

    Args:
        overwrite (bool, optional): Whether to rewrite the out-of-date files instead of raising an error.
            Defaults to `False`.

    Raises:
        ValueError: If a dummy file differs from the expected content and `overwrite` is `False`.
    """
    dummy_files = create_dummy_files()
    # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}
    utils_dir = os.path.join(PATH_TO_TRANSFORMERS, "utils")

    for backend, expected_content in dummy_files.items():
        module_name = f"dummy_{short_names.get(backend, backend)}_objects"
        file_path = os.path.join(utils_dir, f"{module_name}.py")

        # A backend that was just added to the init has no dummy file yet: compare against empty content
        # instead of failing on the read.
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                actual_content = f.read()
        else:
            actual_content = ""

        if expected_content == actual_content:
            continue

        if not overwrite:
            raise ValueError(
                "The main __init__ has objects that are not present in "
                f"transformers.utils.{module_name}.py. Run `make fix-copies` "
                "to fix this."
            )

        print(f"Updating transformers.utils.{module_name}.py as the main __init__ has new objects.")
        with open(file_path, "w", encoding="utf-8", newline="\n") as f:
            f.write(expected_content)
# Script entry point: check the dummy files, rewriting them when `--fix_and_overwrite` is passed.
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    check_dummies(cli.parse_args().fix_and_overwrite)