# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

# Matches an indented `from x import a, b` statement written on a single line (i.e. not followed by a parenthesis)
# and captures the comma-separated list of imported names.
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")

# Indentation of an object name listed inside a parenthesized multi-line import that itself sits in an
# `if is_xxx_available():` block. Spelled as a product so the width is explicit.
_MULTILINE_OBJECT_INDENT = " " * 8

# Name fragments identifying objects that expose a `from_pretrained` class method.
_PRETRAINED_NAME_PARTS = (
    "Config",
    "ForCausalLM",
    "ForConditionalGeneration",
    "ForMaskedLM",
    "ForMultipleChoice",
    "ForQuestionAnswering",
    "ForSequenceClassification",
    "ForTokenClassification",
    "Model",
    "Tokenizer",
)


DUMMY_CONSTANT = """
{0} = None
"""

DUMMY_PT_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)
"""

DUMMY_PT_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)
"""

DUMMY_PT_FUNCTION = """
def {0}(*args, **kwargs):
    requires_pytorch({0})
"""


DUMMY_TF_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_tf(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)
"""

DUMMY_TF_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
"""

DUMMY_TF_FUNCTION = """
def {0}(*args, **kwargs):
    requires_tf({0})
"""


def _collect_backend_objects(lines, start_index, marker):
    """
    Collect the names imported inside the indented block opened by the line starting with `marker`.

    Args:
        lines: the lines of the main init, each one ending with its newline.
        start_index: index in `lines` at which to start looking for `marker`.
        marker: prefix of the line opening the block (for instance `"if is_torch_available():"`).

    Returns:
        A tuple `(objects, next_index)` with the list of collected names and the index of the first line
        after the block (equal to `len(lines)` if the block runs until the end of the file).

    Raises:
        ValueError: if no line at or after `start_index` starts with `marker`.
    """
    line_index = start_index
    while line_index < len(lines) and not lines[line_index].startswith(marker):
        line_index += 1
    if line_index >= len(lines):
        raise ValueError("Could not find a line starting with `{}` in the main __init__.".format(marker))
    line_index += 1

    objects = []
    # Until we unindent (blank lines do not end the block) or reach the end of the file, collect the objects.
    while line_index < len(lines) and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" ")):
        line = lines[line_index]
        search = _re_single_line_import.search(line)
        if search is not None:
            # `from x import a, b` on one line.
            objects += search.groups()[0].split(", ")
        elif line.startswith(_MULTILINE_OBJECT_INDENT):
            # One name per line inside a parenthesized import: drop the indent and the trailing ",\n".
            # Requiring the full indent keeps the `from x import (` and `)` lines out of the result.
            objects.append(line[len(_MULTILINE_OBJECT_INDENT) : -2])
        line_index += 1

    return objects, line_index


def read_init():
    """
    Read the main init and extract the PyTorch and TensorFlow objects.

    Returns:
        A tuple `(pt_objects, tf_objects)` of lists with the names imported under `if is_torch_available():` and
        `if is_tf_available():` respectively. The TensorFlow block is searched for after the PyTorch one.

    Raises:
        ValueError: if one of the two blocks cannot be found in the init.
    """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
        lines = f.readlines()

    pt_objects, line_index = _collect_backend_objects(lines, 0, "if is_torch_available():")
    tf_objects, _ = _collect_backend_objects(lines, line_index, "if is_tf_available():")
    return pt_objects, tf_objects


def create_dummy_object(name, is_pytorch=True):
    """
    Create the code for the dummy object corresponding to `name`.

    The kind of dummy is picked from the casing of `name`: all uppercase gives a constant set to `None`, all
    lowercase gives a function, anything else gives a class. A class whose name contains one of
    `_PRETRAINED_NAME_PARTS` also gets a `from_pretrained` class method.

    Args:
        name: the name of the object to create a dummy for.
        is_pytorch: whether to use the PyTorch templates (`True`) or the TensorFlow ones (`False`).

    Returns:
        The source code of the dummy object, as a string.
    """
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    if name.islower():
        return (DUMMY_PT_FUNCTION if is_pytorch else DUMMY_TF_FUNCTION).format(name)

    if any(part in name for part in _PRETRAINED_NAME_PARTS):
        template = DUMMY_PT_PRETRAINED_CLASS if is_pytorch else DUMMY_TF_PRETRAINED_CLASS
    else:
        template = DUMMY_PT_CLASS if is_pytorch else DUMMY_TF_CLASS
    return template.format(name)


def create_dummy_files():
    """
    Create the content of the dummy files.

    Returns:
        A tuple `(pt_dummies, tf_dummies)` with the full expected content of the PyTorch and TensorFlow dummy
        files, built from the objects returned by `read_init`.
    """
    pt_objects, tf_objects = read_init()

    pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    pt_dummies += "from ..file_utils import requires_pytorch\n\n"
    pt_dummies += "\n".join([create_dummy_object(o) for o in pt_objects])

    tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    tf_dummies += "from ..file_utils import requires_tf\n\n"
    tf_dummies += "\n".join([create_dummy_object(o, False) for o in tf_objects])

    return pt_dummies, tf_dummies


def _check_dummy_file(expected, file_path, short_name, overwrite):
    """
    Compare the content of `file_path` with `expected` and either rewrite the file or raise.

    Args:
        expected: the content the file should have.
        file_path: the path of the dummy file to check.
        short_name: the file name without extension (e.g. `"dummy_pt_objects"`), used in the messages.
        overwrite: whether to rewrite the file when it differs, instead of raising.

    Raises:
        ValueError: if the file content differs from `expected` and `overwrite` is `False`.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        actual = f.read()

    if actual == expected:
        return

    if not overwrite:
        # A single message: passing several strings to the exception would render it as a tuple.
        raise ValueError(
            "The main __init__ has objects that are not present in transformers.utils.{}.py. "
            "Run `make fix-copies` to fix this.".format(short_name)
        )

    print("Updating transformers.utils.{}.py as the main __init__ has new objects.".format(short_name))
    # Force "\n" line endings so the generated file is the same on every platform.
    with open(file_path, "w", encoding="utf-8", newline="\n") as f:
        f.write(expected)


def check_dummies(overwrite=False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    Args:
        overwrite: when `True`, outdated dummy files are rewritten; when `False`, a `ValueError` is raised for
            the first outdated file (the PyTorch file is checked before the TensorFlow one).

    Raises:
        ValueError: if a dummy file is outdated and `overwrite` is `False`.
    """
    pt_dummies, tf_dummies = create_dummy_files()
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")

    _check_dummy_file(pt_dummies, os.path.join(path, "dummy_pt_objects.py"), "dummy_pt_objects", overwrite)
    _check_dummy_file(tf_dummies, os.path.join(path, "dummy_tf_objects.py"), "dummy_tf_objects", overwrite)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_dummies(args.fix_and_overwrite)