From 81156d20cd76c1a43ed44fdbc785e237d60b6896 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 24 Jan 2022 15:25:10 -0500 Subject: [PATCH] Add model like (#14992) * Add new model like command * Bad doc-styler * black and doc-styler, stop fighting! * black and doc-styler, stop fighting! * At last * Clean up * Typo * Bad doc-styler * Bad doc-styler * All good maybe? * Use constants * Add doc and type hints * More cleaning * Add doc * Fix Copied from * Doc template * Use typing.Pattern instead * Framework-specific files * Fixes * Select frameworks clean model init * Deal with frameworks in main init * fixes * Last fix * Prompt user for info * Delete exemple config * Last fixes * Add test config * Fix bug with model_type included in each other * Fixes * More fixes * More fixes * Adapt config * Remove print statements * Will fix tokenization later, leave it broken for now * Add test * Quality * Try this way * Debug * Maybe by setting the path? * Let's try another way * It should go better when actually passing the arg... * Remove debug statements and style * Fix config * Add tests * Test require the three backends * intermediate commit * Revamp pattern replacements and start work on feature extractors * Adapt model info * Finalize code for processors * Fix in main init additions * Finish questionnaire for processing classes * Fix file name * Fix for real * Fix patterns * Style * Remove needless warnings * Copied from should work now. * Include Copied form in blocks * Add test * More fixes and tests * Apply suggestions from code review Co-authored-by: Lysandre Debut * Address review comment Co-authored-by: Lysandre Debut --- .github/workflows/add-model-like.yml | 60 + .../commands/add_new_model_like.py | 1478 +++++++++++++++++ src/transformers/commands/transformers_cli.py | 2 + .../fixtures/add_distilbert_like_config.json | 19 + tests/test_add_new_model_like.py | 1342 +++++++++++++++ utils/tests_fetcher.py | 1 + 6 files changed, 2902 insertions(+) create mode 100644 .github/workflows/add-model-like.yml create mode 100644 src/transformers/commands/add_new_model_like.py create mode 100644 tests/fixtures/add_distilbert_like_config.json create mode 100644 tests/test_add_new_model_like.py diff --git a/.github/workflows/add-model-like.yml b/.github/workflows/add-model-like.yml new file mode 100644 index 00000000000..89f6e840aae --- /dev/null +++ b/.github/workflows/add-model-like.yml @@ -0,0 +1,60 @@ +name: Add model like runner + +on: + push: + branches: + - master + pull_request: + paths: + - "src/**" + - "tests/**" + - ".github/**" + types: [opened, synchronize, reopened] + +jobs: + run_tests_templates: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Loading cache. + uses: actions/cache@v2 + id: cache + with: + path: ~/.cache/pip + key: v1-tests_model_like + restore-keys: | + v1-tests_model_like-${{ hashFiles('setup.py') }} + v1-tests_model_like + + - name: Install dependencies + run: | + pip install --upgrade pip!=21.3 + sudo apt -y update && sudo apt install -y libsndfile1-dev + pip install .[dev] + + - name: Create model files + run: | + transformers-cli add-new-model-like --config_file tests/fixtures/add_distilbert_like_config.json --path_to_repo . 
+          make style
+          make fix-copies
+
+      - name: Run all PyTorch modeling tests
+        run: |
+          python -m pytest -n 2 --dist=loadfile -s --make-reports=tests_new_models tests/test_modeling_bert_new.py
+
+      - name: Run style changes
+        run: |
+          git fetch origin master:master
+          make style && make quality && make repo-consistency
+
+      - name: Failure short reports
+        if: ${{ always() }}
+        run: cat reports/tests_new_models_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: run_all_tests_new_models_test_reports
+          path: reports
diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
new file mode 100644
index 00000000000..e443a235c42
--- /dev/null
+++ b/src/transformers/commands/add_new_model_like.py
@@ -0,0 +1,1478 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import difflib
+import json
+import os
+import re
+from argparse import ArgumentParser, Namespace
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
+
+import transformers.models.auto as auto_module
+from transformers.models.auto.configuration_auto import model_type_to_module_name
+
+from ..utils import logging
+from . import BaseTransformersCLICommand
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+TRANSFORMERS_PATH = Path(__file__).parent.parent
+REPO_PATH = TRANSFORMERS_PATH.parent.parent
+
+
+@dataclass
+class ModelPatterns:
+    """
+    Holds the basic information about a new model for the add-new-model-like command.
+
+    Args:
+        model_name (`str`): The model name.
+        checkpoint (`str`): The checkpoint to use for doc examples.
+        model_type (`str`, *optional*):
+            The model type, the identifier used internally in the library like `bert` or `xlm-roberta`. Will default
+            to `model_name` lowercased with spaces replaced with minuses (-).
+        model_lower_cased (`str`, *optional*):
+            The lowercased version of the model name, to use for the module name or function names. Will default to
+            `model_name` lowercased with spaces and minuses replaced with underscores.
+        model_camel_cased (`str`, *optional*):
+            The camel-cased version of the model name, to use for the class names. Will default to `model_name`
+            camel-cased (with spaces and minuses both considered as word separators).
+        model_upper_cased (`str`, *optional*):
+            The uppercased version of the model name, to use for the constant names. Will default to `model_name`
+            uppercased with spaces and minuses replaced with underscores.
+        config_class (`str`, *optional*):
+            The config class associated with this model. Will default to `"{model_camel_cased}Config"`.
+        tokenizer_class (`str`, *optional*):
+            The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
+ feature_extractor_class (`str`, *optional*): + The feature extractor class associated with this model (leave to `None` for models that don't use a feature + extractor). + processor_class (`str`, *optional*): + The processor class associated with this model (leave to `None` for models that don't use a processor). + """ + + model_name: str + checkpoint: str + model_type: Optional[str] = None + model_lower_cased: Optional[str] = None + model_camel_cased: Optional[str] = None + model_upper_cased: Optional[str] = None + config_class: Optional[str] = None + tokenizer_class: Optional[str] = None + feature_extractor_class: Optional[str] = None + processor_class: Optional[str] = None + + def __post_init__(self): + if self.model_type is None: + self.model_type = self.model_name.lower().replace(" ", "-") + if self.model_lower_cased is None: + self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_") + if self.model_camel_cased is None: + # Split the model name on - and space + words = self.model_name.split(" ") + words = list(chain(*[w.split("-") for w in words])) + # Make sure each word is capitalized + words = [w[0].upper() + w[1:] for w in words] + self.model_camel_cased = "".join(words) + if self.model_upper_cased is None: + self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_") + if self.config_class is None: + self.config_class = f"{self.model_camel_cased}Config" + + +ATTRIBUTE_TO_PLACEHOLDER = { + "config_class": "[CONFIG_CLASS]", + "tokenizer_class": "[TOKENIZER_CLASS]", + "feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]", + "processor_class": "[PROCESSOR_CLASS]", + "checkpoint": "[CHECKPOINT]", + "model_type": "[MODEL_TYPE]", + "model_upper_cased": "[MODEL_UPPER_CASED]", + "model_camel_cased": "[MODEL_CAMELCASED]", + "model_lower_cased": "[MODEL_LOWER_CASED]", + "model_name": "[MODEL_NAME]", +} + + +def is_empty_line(line: str) -> bool: + """ + Determines whether a line is empty or not. + """ + return len(line) == 0 or line.isspace() + + +def find_indent(line: str) -> int: + """ + Returns the number of spaces that start a line indent. + """ + search = re.search("^(\s*)(?:\S|$)", line) + if search is None: + return 0 + return len(search.groups()[0]) + + +def parse_module_content(content: str) -> List[str]: + """ + Parse the content of a module in the list of objects it defines. + + Args: + content (`str`): The content to parse + + Returns: + `List[str]`: The list of objects defined in the module. + """ + objects = [] + current_object = [] + lines = content.split("\n") + # Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this. 
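+    # For instance, a (hypothetical) top-level constant such as
+    #
+    #     FOO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    #         "foo-base": "https://huggingface.co/foo-base/resolve/main/config.json",
+    #     }
+    #
+    # only ends on the closing "}" line, which is why closing brackets are treated as end markers below.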
+    end_markers = [")", "]", "}", '"""']
+
+    for line in lines:
+        # End of an object
+        is_valid_object = len(current_object) > 0
+        if is_valid_object and len(current_object) == 1:
+            is_valid_object = not current_object[0].startswith("# Copied from")
+        if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
+            # Closing parts should be included in current object
+            if line in end_markers:
+                current_object.append(line)
+                objects.append("\n".join(current_object))
+                current_object = []
+            else:
+                objects.append("\n".join(current_object))
+                current_object = [line]
+        else:
+            current_object.append(line)
+
+    # Add last object
+    if len(current_object) > 0:
+        objects.append("\n".join(current_object))
+
+    return objects
+
+
+def add_content_to_text(
+    text: str,
+    content: str,
+    add_after: Optional[Union[str, Pattern]] = None,
+    add_before: Optional[Union[str, Pattern]] = None,
+    exact_match: bool = False,
+) -> str:
+    """
+    A utility to add some content inside a given text.
+
+    Args:
+        text (`str`): The text in which we want to insert some content.
+        content (`str`): The content to add.
+        add_after (`str` or `Pattern`):
+            The pattern to test on a line of `text`, the new content is added after the first instance matching it.
+        add_before (`str` or `Pattern`):
+            The pattern to test on a line of `text`, the new content is added before the first instance matching it.
+        exact_match (`bool`, *optional*, defaults to `False`):
+            A line is considered a match with `add_after` or `add_before` if it matches exactly when
+            `exact_match=True`, otherwise, if `add_after`/`add_before` is present in the line.
+
+    <Tip warning={true}>
+
+    The arguments `add_after` and `add_before` are mutually exclusive, and exactly one needs to be provided.
+
+    </Tip>
+
+    Returns:
+        `str`: The text with the new content added if a match was found.
+    """
+    if add_after is None and add_before is None:
+        raise ValueError("You need to pass either `add_after` or `add_before`")
+    if add_after is not None and add_before is not None:
+        raise ValueError("You can't pass both `add_after` and `add_before`")
+    pattern = add_after if add_before is None else add_before
+
+    def this_is_the_line(line):
+        if isinstance(pattern, Pattern):
+            return pattern.search(line) is not None
+        elif exact_match:
+            return pattern == line
+        else:
+            return pattern in line
+
+    new_lines = []
+    for line in text.split("\n"):
+        if this_is_the_line(line):
+            if add_before is not None:
+                new_lines.append(content)
+            new_lines.append(line)
+            if add_after is not None:
+                new_lines.append(content)
+        else:
+            new_lines.append(line)
+
+    return "\n".join(new_lines)
+
+
+def add_content_to_file(
+    file_name: Union[str, os.PathLike],
+    content: str,
+    add_after: Optional[Union[str, Pattern]] = None,
+    add_before: Optional[Union[str, Pattern]] = None,
+    exact_match: bool = False,
+):
+    """
+    A utility to add some content inside a given file.
+
+    Args:
+        file_name (`str` or `os.PathLike`): The name of the file in which we want to insert some content.
+        content (`str`): The content to add.
+        add_after (`str` or `Pattern`):
+            The pattern to test on a line of `text`, the new content is added after the first instance matching it.
+        add_before (`str` or `Pattern`):
+            The pattern to test on a line of `text`, the new content is added before the first instance matching it.
+        exact_match (`bool`, *optional*, defaults to `False`):
+            A line is considered a match with `add_after` or `add_before` if it matches exactly when
+            `exact_match=True`, otherwise, if `add_after`/`add_before` is present in the line.
+
+    <Tip warning={true}>
+
+    The arguments `add_after` and `add_before` are mutually exclusive, and exactly one needs to be provided.
+
+    </Tip>
+    """
+    with open(file_name, "r", encoding="utf-8") as f:
+        old_content = f.read()
+
+    new_content = add_content_to_text(
+        old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
+    )
+
+    with open(file_name, "w", encoding="utf-8") as f:
+        f.write(new_content)
+
+
+def replace_model_patterns(
+    text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
+) -> Tuple[str, str]:
+    """
+    Replace all patterns present in a given text.
+
+    Args:
+        text (`str`): The text to treat.
+        old_model_patterns (`ModelPatterns`): The patterns for the old model.
+        new_model_patterns (`ModelPatterns`): The patterns for the new model.
+
+    Returns:
+        `Tuple(str, str)`: A tuple with the treated text and the replacements actually done in it.
+    """
+    # The order is crucially important as we will check and replace in that order. For instance the config probably
+    # contains the camel-cased name, but will be treated before.
+    attributes_to_check = ["config_class"]
+    # Add relevant preprocessing classes
+    for attr in ["tokenizer_class", "feature_extractor_class", "processor_class"]:
+        if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
+            attributes_to_check.append(attr)
+
+    # Special cases for checkpoint and model_type
+    if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
+        attributes_to_check.append("checkpoint")
+    if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
+        attributes_to_check.append("model_type")
+    else:
+        text = re.sub(
+            fr'(\s*)model_type = "{old_model_patterns.model_type}"',
+            r'\1model_type = "[MODEL_TYPE]"',
+            text,
+        )
+
+    # Special case when the model camel cased and upper cased names are the same for the old model (like for GPT2) but
+    # not the new one. We can't just do a replace in all the text and will need a special regex
+    if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
+        old_model_value = old_model_patterns.model_upper_cased
+        if re.search(fr"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
+            text = re.sub(fr"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
+    else:
+        attributes_to_check.append("model_upper_cased")
+
+    attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])
+
+    # Now let's replace every other attribute by their placeholder
+    for attr in attributes_to_check:
+        text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])
+
+    # Finally we can replace the placeholders by the new values.
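+    # For instance, duplicating BERT as a hypothetical "BertNew" model, a line like
+    #     class BertModel(BertPreTrainedModel):
+    # has been turned into
+    #     class [MODEL_CAMELCASED]Model([MODEL_CAMELCASED]PreTrainedModel):
+    # by the replacements above, and now becomes
+    #     class BertNewModel(BertNewPreTrainedModel):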
+    replacements = []
+    for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
+        if placeholder in text:
+            replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
+            text = text.replace(placeholder, getattr(new_model_patterns, attr))
+
+    # If we have two inconsistent replacements, we don't return anything (ex: GPT2->GPT_NEW and GPT2->GPTNew)
+    old_replacement_values = [old for old, new in replacements]
+    if len(set(old_replacement_values)) != len(old_replacement_values):
+        return text, ""
+
+    replacements = simplify_replacements(replacements)
+    replacements = [f"{old}->{new}" for old, new in replacements]
+    return text, ",".join(replacements)
+
+
+def simplify_replacements(replacements):
+    """
+    Simplify a list of replacement patterns to make sure there are no needless ones.
+
+    For instance in the sequence "Bert->BertNew, BertConfig->BertNewConfig, bert->bert_new", the replacement
+    "BertConfig->BertNewConfig" is implied by "Bert->BertNew" so not needed.
+
+    Args:
+        replacements (`List[Tuple[str, str]]`): List of patterns (old, new)
+
+    Returns:
+        `List[Tuple[str, str]]`: The list of patterns simplified.
+    """
+    if len(replacements) <= 1:
+        # Nothing to simplify
+        return replacements
+
+    # Next let's sort replacements by length as a replacement can only "imply" another replacement if it's shorter.
+    replacements.sort(key=lambda x: len(x[0]))
+
+    idx = 0
+    while idx < len(replacements):
+        old, new = replacements[idx]
+        # Loop through all replacements after
+        j = idx + 1
+        while j < len(replacements):
+            old_2, new_2 = replacements[j]
+            # If the replacement is implied by the current one, we can drop it.
+            if old_2.replace(old, new) == new_2:
+                replacements.pop(j)
+            else:
+                j += 1
+        idx += 1
+
+    return replacements
+
+
+def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
+    """
+    Returns the module name corresponding to a module file.
+    """
+    full_module_path = Path(module_file).absolute()
+    module_parts = full_module_path.with_suffix("").parts
+
+    # Find the first part named transformers, starting from the end.
+    idx = len(module_parts) - 1
+    while idx >= 0 and module_parts[idx] != "transformers":
+        idx -= 1
+    if idx < 0:
+        raise ValueError(f"{module_file} is not a transformers module.")
+
+    return ".".join(module_parts[idx:])
+
+
+SPECIAL_PATTERNS = {
+    "_CHECKPOINT_FOR_DOC =": "checkpoint",
+    "_CONFIG_FOR_DOC =": "config_class",
+    "_TOKENIZER_FOR_DOC =": "tokenizer_class",
+    "_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
+    "_PROCESSOR_FOR_DOC =": "processor_class",
+}
+
+
+_re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)
+
+
+def duplicate_module(
+    module_file: Union[str, os.PathLike],
+    old_model_patterns: ModelPatterns,
+    new_model_patterns: ModelPatterns,
+    dest_file: Optional[str] = None,
+    add_copied_from: bool = True,
+):
+    """
+    Create a new module from an existing one, adapting all function and class names from the old patterns to the new
+    ones.
+
+    Args:
+        module_file (`str` or `os.PathLike`): Path to the module to duplicate.
+        old_model_patterns (`ModelPatterns`): The patterns for the old model.
+        new_model_patterns (`ModelPatterns`): The patterns for the new model.
+        dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
+        add_copied_from (`bool`, *optional*, defaults to `True`):
+            Whether or not to add `# Copied from` statements in the duplicated module.
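+
+    For example (with hypothetical patterns for BERT and a "BertNew" model), calling this on `modeling_bert.py`
+    would write a `modeling_bert_new.py` with every class and function renamed and, where the content matches the
+    original, a `# Copied from transformers.models.bert...` statement added above it.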
+ """ + if dest_file is None: + dest_file = str(module_file).replace( + old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased + ) + + with open(module_file, "r", encoding="utf-8") as f: + content = f.read() + + objects = parse_module_content(content) + + # Loop and treat all objects + new_objects = [] + for obj in objects: + # Special cases + if "PRETRAINED_CONFIG_ARCHIVE_MAP = {" in obj: + # docstyle-ignore + obj = ( + f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP = " + + "{" + + f""" + "{new_model_patterns.checkpoint}": "https://huggingface.co/{new_model_patterns.checkpoint}/resolve/main/config.json", +""" + + "}\n" + ) + new_objects.append(obj) + continue + elif "PRETRAINED_MODEL_ARCHIVE_LIST = [" in obj: + if obj.startswith("TF_"): + prefix = "TF_" + elif obj.startswith("FLAX_"): + prefix = "FLAX_" + else: + prefix = "" + # docstyle-ignore + obj = f"""{prefix}{new_model_patterns.model_upper_cased}_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "{new_model_patterns.checkpoint}", + # See all {new_model_patterns.model_name} models at https://huggingface.co/models?filter={new_model_patterns.model_type} +] +""" + new_objects.append(obj) + continue + + special_pattern = False + for pattern, attr in SPECIAL_PATTERNS.items(): + if pattern in obj: + obj = obj.replace(getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)) + new_objects.append(obj) + special_pattern = True + break + + if special_pattern: + continue + + # Regular classes functions + old_obj = obj + obj, replacement = replace_model_patterns(obj, old_model_patterns, new_model_patterns) + has_copied_from = re.search("^#\s+Copied from", obj, flags=re.MULTILINE) is not None + if add_copied_from and not has_copied_from and _re_class_func.search(obj) is not None and len(replacement) > 0: + # Copied from statement must be added just before the class/function definition, which may not be the + # first line because of decorators. + module_name = get_module_from_file(module_file) + old_object_name = _re_class_func.search(old_obj).groups()[0] + obj = add_content_to_text( + obj, f"# Copied from {module_name}.{old_object_name} with {replacement}", add_before=_re_class_func + ) + # In all cases, we remove Copied from statement with indent on methods. + obj = re.sub("\n[ ]+# Copied from [^\n]*\n", "\n", obj) + + new_objects.append(obj) + + with open(dest_file, "w", encoding="utf-8") as f: + content = f.write("\n".join(new_objects)) + + +def filter_framework_files( + files: List[Union[str, os.PathLike]], frameworks: Optional[List[str]] = None +) -> List[Union[str, os.PathLike]]: + """ + Filter a list of files to only keep the ones corresponding to a list of frameworks. + + Args: + files (`List[Union[str, os.PathLike]]`): The list of files to filter. + frameworks (`List[str]`, *optional*): The list of allowed frameworks. + + Returns: + `List[Union[str, os.PathLike]]`: The list of filtered files. + """ + if frameworks is None: + return files + + framework_to_file = {} + others = [] + for f in files: + parts = Path(f).name.split("_") + if "modeling" not in parts: + others.append(f) + continue + if "tf" in parts: + framework_to_file["tf"] = f + elif "flax" in parts: + framework_to_file["flax"] = f + else: + framework_to_file["pt"] = f + + return [framework_to_file[f] for f in frameworks] + others + + +def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, Union[Path, List[Path]]]: + """ + Retrieves all the files associated to a model. 
+ + Args: + model_type (`str`): A valid model type (like "bert" or "gpt2") + frameworks (`List[str]`, *optional*): + If passed, will only keep the model files corresponding to the passed frameworks. + + Returns: + `Dict[str, Union[Path, List[Path]]]`: A dictionary with the following keys: + - **doc_file** -- The documentation file for the model. + - **model_files** -- All the files in the model module. + - **test_files** -- The test files for the model. + """ + module_name = model_type_to_module_name(model_type) + + model_module = TRANSFORMERS_PATH / "models" / module_name + model_files = list(model_module.glob("*.py")) + model_files = filter_framework_files(model_files, frameworks=frameworks) + + doc_file = REPO_PATH / "docs" / "source" / "model_doc" / f"{model_type}.mdx" + + # Basic pattern for test files + test_files = [ + f"test_modeling_{module_name}.py", + f"test_modeling_tf_{module_name}.py", + f"test_modeling_flax_{module_name}.py", + f"test_tokenization_{module_name}.py", + f"test_feature_extraction_{module_name}.py", + f"test_processor_{module_name}.py", + ] + test_files = filter_framework_files(test_files, frameworks=frameworks) + # Add the test directory + test_files = [REPO_PATH / "tests" / f for f in test_files] + # Filter by existing files + test_files = [f for f in test_files if f.exists()] + + return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files} + + +_re_checkpoint_for_doc = re.compile("^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE) + + +def find_base_model_checkpoint( + model_type: str, model_files: Optional[Dict[str, Union[Path, List[Path]]]] = None +) -> str: + """ + Finds the model checkpoint used in the docstrings for a given model. + + Args: + model_type (`str`): A valid model type (like "bert" or "gpt2") + model_files (`Dict[str, Union[Path, List[Path]]`, *optional*): + The files associated to `model_type`. Can be passed to speed up the function, otherwise will be computed. + + Returns: + `str`: The checkpoint used. + """ + if model_files is None: + model_files = get_model_files(model_type) + module_files = model_files["model_files"] + for fname in module_files: + if "modeling" not in str(fname): + continue + + with open(fname, "r", encoding="utf-8") as f: + content = f.read() + if _re_checkpoint_for_doc.search(content) is not None: + checkpoint = _re_checkpoint_for_doc.search(content).groups()[0] + # Remove quotes + checkpoint = checkpoint.replace('"', "") + checkpoint = checkpoint.replace("'", "") + return checkpoint + + # TODO: Find some kind of fallback if there is no _CHECKPOINT_FOR_DOC in any of the modeling file. + return "" + + +_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES") + + +def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, List[str]]: + """ + Retrieve the model classes associated to a given model. + + Args: + model_type (`str`): A valid model type (like "bert" or "gpt2") + frameworks (`List[str]`, *optional*): + The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict + the classes returned. + + Returns: + `Dict[str, List[str]]`: A dictionary with one key per framework and the list of model classes associated to + that framework as values. 
+ """ + if frameworks is None: + frameworks = ["pt", "tf", "flax"] + + modules = { + "pt": auto_module.modeling_auto, + "tf": auto_module.modeling_tf_auto, + "flax": auto_module.modeling_flax_auto, + } + + model_classes = {} + for framework in frameworks: + new_model_classes = [] + model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None] + for model_mapping_name in model_mappings: + model_mapping = getattr(modules[framework], model_mapping_name) + if model_type in model_mapping: + new_model_classes.append(model_mapping[model_type]) + + if len(new_model_classes) > 0: + # Remove duplicates + model_classes[framework] = list(set(new_model_classes)) + + return model_classes + + +def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None): + """ + Retrieves all the information from a given model_type. + + Args: + model_type (`str`): A valid model type (like "bert" or "gpt2") + frameworks (`List[str]`, *optional*): + If passed, will only keep the info corresponding to the passed frameworks. + + Returns: + `Dict`: A dictionary with the following keys: + - **frameworks** (`List[str]`): The list of frameworks that back this model type. + - **model_classes** (`Dict[str, List[str]]`): The model classes implemented for that model type. + - **model_files** (`Dict[str, Union[Path, List[Path]]]`): The files associated with that model type. + - **model_patterns** (`ModelPatterns`): The various patterns for the model. + """ + if model_type not in auto_module.MODEL_NAMES_MAPPING: + raise ValueError(f"{model_type} is not a valid model type.") + + model_name = auto_module.MODEL_NAMES_MAPPING[model_type] + config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type] + archive_map = auto_module.configuration_auto.CONFIG_ARCHIVE_MAP_MAPPING_NAMES.get(model_type, None) + if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES: + tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type] + tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1] + else: + tokenizer_class = None + feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None) + processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None) + + model_files = get_model_files(model_type, frameworks=frameworks) + model_camel_cased = config_class.replace("Config", "") + + available_frameworks = [] + for fname in model_files["model_files"]: + if "modeling_tf" in str(fname): + available_frameworks.append("tf") + elif "modeling_flax" in str(fname): + available_frameworks.append("flax") + elif "modeling" in str(fname): + available_frameworks.append("pt") + + if frameworks is None: + frameworks = available_frameworks.copy() + else: + frameworks = [f for f in frameworks if f in available_frameworks] + + model_classes = retrieve_model_classes(model_type, frameworks=frameworks) + + # Retrieve model upper-cased name from the constant name of the pretrained archive map. 
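+    # e.g. an archive map named "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP" yields the upper-cased name "BERT".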
+ if archive_map is None: + model_upper_cased = model_camel_cased.upper() + else: + parts = archive_map.split("_") + idx = 0 + while idx < len(parts) and parts[idx] != "PRETRAINED": + idx += 1 + if idx < len(parts): + model_upper_cased = "_".join(parts[:idx]) + else: + model_upper_cased = model_camel_cased.upper() + + model_patterns = ModelPatterns( + model_name, + checkpoint=find_base_model_checkpoint(model_type, model_files=model_files), + model_type=model_type, + model_camel_cased=model_camel_cased, + model_lower_cased=model_files["module_name"], + model_upper_cased=model_upper_cased, + config_class=config_class, + tokenizer_class=tokenizer_class, + feature_extractor_class=feature_extractor_class, + processor_class=processor_class, + ) + + return { + "frameworks": frameworks, + "model_classes": model_classes, + "model_files": model_files, + "model_patterns": model_patterns, + } + + +def clean_frameworks_in_init( + init_file: Union[str, os.PathLike], frameworks: Optional[List[str]] = None, keep_processing: bool = True +): + """ + Removes all the import lines that don't belong to a given list of frameworks or concern tokenizers/feature + extractors/processors in an init. + + Args: + init_file (`str` or `os.PathLike`): The path to the init to treat. + frameworks (`List[str]`, *optional*): + If passed, this will remove all imports that are subject to a framework not in frameworks + keep_processing (`bool`, *optional*, defaults to `True`): + Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init. + """ + if frameworks is None: + frameworks = ["pt", "tf", "flax"] + + names = {"pt": "torch"} + to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks] + if not keep_processing: + to_remove.extend(["sentencepiece", "tokenizers", "vision"]) + + if len(to_remove) == 0: + # Nothing to do + return + + remove_pattern = "|".join(to_remove) + re_conditional_imports = re.compile(fr"^\s*if is_({remove_pattern})_available\(\):\s*$") + re_is_xxx_available = re.compile(fr"is_({remove_pattern})_available") + + with open(init_file, "r", encoding="utf-8") as f: + content = f.read() + + lines = content.split("\n") + new_lines = [] + idx = 0 + while idx < len(lines): + # Conditional imports + if re_conditional_imports.search(lines[idx]) is not None: + idx += 1 + while is_empty_line(lines[idx]): + idx += 1 + indent = find_indent(lines[idx]) + while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]): + idx += 1 + # Remove the import from file_utils + elif re_is_xxx_available.search(lines[idx]) is not None: + line = lines[idx] + for framework in to_remove: + line = line.replace(f", is_{framework}_available", "") + line = line.replace(f"is_{framework}_available, ", "") + line = line.replace(f"is_{framework}_available", "") + + if len(line.strip()) > 0: + new_lines.append(line) + idx += 1 + # Otherwise we keep the line, except if it's a tokenizer import and we don't want to keep it. 
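+        # (e.g. with `keep_processing=False`, a quoted entry like "tokenization_bert_new" in `_import_structure`
+        # or a line like `from .tokenization_bert_new import BertNewTokenizer` is dropped).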
+        elif keep_processing or (
+            re.search('^\s*"(tokenization|processing|feature_extraction)', lines[idx]) is None
+            and re.search("^\s*from .(tokenization|processing|feature_extraction)", lines[idx]) is None
+        ):
+            new_lines.append(lines[idx])
+            idx += 1
+        else:
+            idx += 1
+
+    with open(init_file, "w", encoding="utf-8") as f:
+        f.write("\n".join(new_lines))
+
+
+def add_model_to_main_init(
+    old_model_patterns: ModelPatterns,
+    new_model_patterns: ModelPatterns,
+    frameworks: Optional[List[str]] = None,
+    with_processing: bool = True,
+):
+    """
+    Add a model to the main init of Transformers.
+
+    Args:
+        old_model_patterns (`ModelPatterns`): The patterns for the old model.
+        new_model_patterns (`ModelPatterns`): The patterns for the new model.
+        frameworks (`List[str]`, *optional*):
+            If specified, only the models implemented in those frameworks will be added.
+        with_processing (`bool`, *optional*, defaults to `True`):
+            Whether the tokenizer/feature extractor/processor of the model should also be added to the init or not.
+    """
+    with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
+        content = f.read()
+
+    lines = content.split("\n")
+    idx = 0
+    new_lines = []
+    framework = None
+    while idx < len(lines):
+        if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
+            framework = None
+        elif lines[idx].lstrip().startswith("if is_torch_available"):
+            framework = "pt"
+        elif lines[idx].lstrip().startswith("if is_tf_available"):
+            framework = "tf"
+        elif lines[idx].lstrip().startswith("if is_flax_available"):
+            framework = "flax"
+
+        # Skip if we are in a framework not wanted.
+        if framework is not None and frameworks is not None and framework not in frameworks:
+            new_lines.append(lines[idx])
+            idx += 1
+        elif re.search(fr'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
+            block = [lines[idx]]
+            indent = find_indent(lines[idx])
+            idx += 1
+            while find_indent(lines[idx]) > indent:
+                block.append(lines[idx])
+                idx += 1
+            if lines[idx].strip() in [")", "]", "],"]:
+                block.append(lines[idx])
+                idx += 1
+            block = "\n".join(block)
+            new_lines.append(block)
+
+            add_block = True
+            if not with_processing:
+                processing_classes = [
+                    old_model_patterns.tokenizer_class,
+                    old_model_patterns.feature_extractor_class,
+                    old_model_patterns.processor_class,
+                ]
+                # Only keep the ones that are not None
+                processing_classes = [c for c in processing_classes if c is not None]
+                for processing_class in processing_classes:
+                    block = block.replace(f' "{processing_class}",', "")
+                    block = block.replace(f', "{processing_class}"', "")
+                    block = block.replace(f" {processing_class},", "")
+                    block = block.replace(f", {processing_class}", "")
+
+                    if processing_class in block:
+                        add_block = False
+            if add_block:
+                new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
+        else:
+            new_lines.append(lines[idx])
+            idx += 1
+
+    with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
+        f.write("\n".join(new_lines))
+
+
+def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
+    """
+    Add a tokenizer to the relevant mappings in the auto module.
+
+    Args:
+        old_model_patterns (`ModelPatterns`): The patterns for the old model.
+        new_model_patterns (`ModelPatterns`): The patterns for the new model.
+    """
+    if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
+        return
+
+    with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
+        content = f.read()
+
+    lines = content.split("\n")
+    idx = 0
+    # First we get to the TOKENIZER_MAPPING_NAMES block.
+    while not lines[idx].startswith("    TOKENIZER_MAPPING_NAMES = OrderedDict("):
+        idx += 1
+    idx += 1
+
+    # That block will end at this line:
+    while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
+        # Either the whole tokenizer block is defined on one line, in which case it ends with "),"
+        if lines[idx].endswith(","):
+            block = lines[idx]
+        # Otherwise it takes several lines until we get to a "),"
+        else:
+            block = []
+            while not lines[idx].startswith("            ),"):
+                block.append(lines[idx])
+                idx += 1
+            block = "\n".join(block)
+        idx += 1
+
+        # If we find the model type and tokenizer class in that block, we have the old model tokenizer block
+        if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
+            break
+
+    new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
+    new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)
+
+    new_lines = lines[:idx] + [new_block] + lines[idx:]
+    with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
+        f.write("\n".join(new_lines))
+
+
+AUTO_CLASSES_PATTERNS = {
+    "configuration_auto.py": [
+        '        ("{model_type}", "{model_name}"),',
+        '        ("{model_type}", "{config_class}"),',
+        '        ("{model_type}", "{pretrained_archive_map}"),',
+    ],
+    "feature_extraction_auto.py": ['        ("{model_type}", "{feature_extractor_class}"),'],
+    "modeling_auto.py": ['        ("{model_type}", "{any_pt_class}"),'],
+    "modeling_tf_auto.py": ['        ("{model_type}", "{any_tf_class}"),'],
+    "modeling_flax_auto.py": ['        ("{model_type}", "{any_flax_class}"),'],
+    "processing_auto.py": ['        ("{model_type}", "{processor_class}"),'],
+}
+
+
+def add_model_to_auto_classes(
+    old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: Dict[str, List[str]]
+):
+    """
+    Add a model to the relevant mappings in the auto module.
+
+    Args:
+        old_model_patterns (`ModelPatterns`): The patterns for the old model.
+        new_model_patterns (`ModelPatterns`): The patterns for the new model.
+        model_classes (`Dict[str, List[str]]`): A dictionary mapping each framework to its list of model classes.
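+            For instance (hypothetical values): `{"pt": ["BertNewModel"], "tf": ["TFBertNewModel"]}`.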
+    """
+    for filename in AUTO_CLASSES_PATTERNS:
+        # Extend patterns with all model classes if necessary
+        new_patterns = []
+        for pattern in AUTO_CLASSES_PATTERNS[filename]:
+            if re.search("any_([a-z]*)_class", pattern) is not None:
+                framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
+                if framework in model_classes:
+                    new_patterns.extend(
+                        [
+                            pattern.replace("{" + f"any_{framework}_class" + "}", cls)
+                            for cls in model_classes[framework]
+                        ]
+                    )
+            elif "{config_class}" in pattern:
+                new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
+            elif "{feature_extractor_class}" in pattern:
+                if (
+                    old_model_patterns.feature_extractor_class is not None
+                    and new_model_patterns.feature_extractor_class is not None
+                ):
+                    new_patterns.append(
+                        pattern.replace("{feature_extractor_class}", old_model_patterns.feature_extractor_class)
+                    )
+            elif "{processor_class}" in pattern:
+                if old_model_patterns.processor_class is not None and new_model_patterns.processor_class is not None:
+                    new_patterns.append(pattern.replace("{processor_class}", old_model_patterns.processor_class))
+            else:
+                new_patterns.append(pattern)
+
+        # Loop through all patterns.
+        for pattern in new_patterns:
+            full_name = TRANSFORMERS_PATH / "models" / "auto" / filename
+            old_model_line = pattern
+            new_model_line = pattern
+            for attr in ["model_type", "model_name"]:
+                old_model_line = old_model_line.replace("{" + attr + "}", getattr(old_model_patterns, attr))
+                new_model_line = new_model_line.replace("{" + attr + "}", getattr(new_model_patterns, attr))
+            if "pretrained_archive_map" in pattern:
+                old_model_line = old_model_line.replace(
+                    "{pretrained_archive_map}", f"{old_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
+                )
+                new_model_line = new_model_line.replace(
+                    "{pretrained_archive_map}", f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
+                )
+
+            new_model_line = new_model_line.replace(
+                old_model_patterns.model_camel_cased, new_model_patterns.model_camel_cased
+            )
+
+            add_content_to_file(full_name, new_model_line, add_after=old_model_line)
+
+    # Tokenizers require special handling
+    insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
+
+
+DOC_OVERVIEW_TEMPLATE = """## Overview
+
+The {model_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
+<INSERT SHORT SUMMARY HERE>
+
+The abstract from the paper is the following:
+
+*<INSERT PAPER ABSTRACT HERE>*
+
+Tips:
+
+<INSERT TIPS ABOUT MODEL HERE>
+
+This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
+The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
+
+"""
+
+
+def duplicate_doc_file(
+    doc_file: Union[str, os.PathLike],
+    old_model_patterns: ModelPatterns,
+    new_model_patterns: ModelPatterns,
+    dest_file: Optional[Union[str, os.PathLike]] = None,
+    frameworks: Optional[List[str]] = None,
+):
+    """
+    Duplicate a documentation file and adapt it for a new model.
+
+    Args:
+        doc_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
+        old_model_patterns (`ModelPatterns`): The patterns for the old model.
+        new_model_patterns (`ModelPatterns`): The patterns for the new model.
+        dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
+            Will default to a file named `{new_model_patterns.model_type}.mdx` in the same folder as `doc_file`.
+        frameworks (`List[str]`, *optional*):
+            If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
+ """ + with open(doc_file, "r", encoding="utf-8") as f: + content = f.read() + + if frameworks is None: + frameworks = ["pt", "tf", "flax"] + if dest_file is None: + dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx" + + # Parse the doc file in blocks. One block per section/header + lines = content.split("\n") + blocks = [] + current_block = [] + + for line in lines: + if line.startswith("#"): + blocks.append("\n".join(current_block)) + current_block = [line] + else: + current_block.append(line) + blocks.append("\n".join(current_block)) + + new_blocks = [] + in_classes = False + for block in blocks: + # Copyright + if not block.startswith("#"): + new_blocks.append(block) + # Main title + elif re.search("^#\s+\S+", block) is not None: + new_blocks.append(f"# {new_model_patterns.model_name}\n") + # The config starts the part of the doc with the classes. + elif not in_classes and old_model_patterns.config_class in block.split("\n")[0]: + in_classes = True + new_blocks.append(DOC_OVERVIEW_TEMPLATE.format(model_name=new_model_patterns.model_name)) + new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns) + new_blocks.append(new_block) + # In classes + elif in_classes: + in_classes = True + block_title = block.split("\n")[0] + block_class = re.search("^#+\s+(\S.*)$", block_title).groups()[0] + new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns) + + if "Tokenizer" in block_class: + # We only add the tokenizer if necessary + if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class: + new_blocks.append(new_block) + elif "FeatureExtractor" in block_class: + # We only add the feature extractor if necessary + if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class: + new_blocks.append(new_block) + elif "Processor" in block_class: + # We only add the processor if necessary + if old_model_patterns.processor_class != new_model_patterns.processor_class: + new_blocks.append(new_block) + elif block_class.startswith("Flax"): + # We only add Flax models if in the selected frameworks + if "flax" in frameworks: + new_blocks.append(new_block) + elif block_class.startswith("TF"): + # We only add TF models if in the selected frameworks + if "tf" in frameworks: + new_blocks.append(new_block) + elif len(block_class.split(" ")) == 1: + # We only add PyTorch models if in the selected frameworks + if "pt" in frameworks: + new_blocks.append(new_block) + else: + new_blocks.append(new_block) + + with open(dest_file, "w", encoding="utf-8") as f: + f.write("\n".join(new_blocks)) + + +def create_new_model_like( + model_type: str, + new_model_patterns: ModelPatterns, + add_copied_from: bool = True, + frameworks: Optional[List[str]] = None, +): + """ + Creates a new model module like a given model of the Transformers library. + + Args: + model_type (`str`): The model type to duplicate (like "bert" or "gpt2") + new_model_patterns (`ModelPatterns`): The patterns for the new model. + add_copied_from (`bool`, *optional*, defaults to `True`): + Whether or not to add "Copied from" statements to all classes in the new model modeling files. + frameworks (`List[str]`, *optional*): + If passed, will limit the duplicate to the frameworks specified. + """ + # Retrieve all the old model info. 
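+    # (e.g. for "distilbert", this gathers the DistilBERT files, classes and ModelPatterns to duplicate from).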
+ model_info = retrieve_info_for_model(model_type, frameworks=frameworks) + model_files = model_info["model_files"] + old_model_patterns = model_info["model_patterns"] + keep_old_processing = True + for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]: + if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr): + keep_old_processing = False + + model_classes = model_info["model_classes"] + + # 1. We create the module for our new model. + old_module_name = model_files["module_name"] + module_folder = TRANSFORMERS_PATH / "models" / new_model_patterns.model_lower_cased + os.makedirs(module_folder, exist_ok=True) + + files_to_adapt = model_files["model_files"] + if keep_old_processing: + files_to_adapt = [ + f + for f in files_to_adapt + if "tokenization" not in str(f) and "processing" not in str(f) and "feature_extraction" not in str(f) + ] + + os.makedirs(module_folder, exist_ok=True) + for module_file in files_to_adapt: + new_module_name = module_file.name.replace( + old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased + ) + dest_file = module_folder / new_module_name + duplicate_module( + module_file, + old_model_patterns, + new_model_patterns, + dest_file=dest_file, + add_copied_from=add_copied_from and "modeling" in new_module_name, + ) + + clean_frameworks_in_init( + module_folder / "__init__.py", frameworks=frameworks, keep_processing=not keep_old_processing + ) + + # 2. We add our new model to the models init and the main init + add_content_to_file( + TRANSFORMERS_PATH / "models" / "__init__.py", + f" {new_model_patterns.model_lower_cased},", + add_after=f" {old_module_name},", + exact_match=True, + ) + add_model_to_main_init( + old_model_patterns, new_model_patterns, frameworks=frameworks, with_processing=not keep_old_processing + ) + + # 3. Add test files + files_to_adapt = model_files["test_files"] + if keep_old_processing: + files_to_adapt = [ + f + for f in files_to_adapt + if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f) + ] + + for test_file in files_to_adapt: + new_test_file_name = test_file.name.replace( + old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased + ) + dest_file = test_file.parent / new_test_file_name + duplicate_module( + test_file, + old_model_patterns, + new_model_patterns, + dest_file=dest_file, + add_copied_from=False, + ) + + # 4. Add model to auto classes + add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes) + + # 5. Add doc file + doc_file = REPO_PATH / "docs" / "source" / "model_doc" / f"{old_model_patterns.model_type}.mdx" + duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks) + + # 6. Warn the user for duplicate patterns + if old_model_patterns.model_type == old_model_patterns.checkpoint: + print( + "The model you picked has the same name for the model type and the checkpoint name " + f"({old_model_patterns.model_type}). As a result, it's possible some places where the new checkpoint " + f"should be, you have {new_model_patterns.model_type} instead. You should search for all instances of " + f"{new_model_patterns.model_type} in the new files and check they're not badly used as checkpoints." + ) + elif old_model_patterns.model_lower_cased == old_model_patterns.checkpoint: + print( + "The model you picked has the same name for the model type and the checkpoint name " + f"({old_model_patterns.model_lower_cased}). 
As a result, it's possible some places where the new "
+            f"checkpoint should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
+            f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
+            "used as checkpoints."
+        )
+    if (
+        old_model_patterns.model_type == old_model_patterns.model_lower_cased
+        and new_model_patterns.model_type != new_model_patterns.model_lower_cased
+    ):
+        print(
+            "The model you picked has the same name for the model type and the lowercased model name "
+            f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
+            f"model type should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
+            f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
+            "used as the model type."
+        )
+
+    if not keep_old_processing and old_model_patterns.tokenizer_class is not None:
+        print(
+            "The constants at the start of the newly created tokenizer file need to be manually fixed. If your new "
+            "model has a fast tokenizer, you will also need to manually add the converter in the "
+            "`SLOW_TO_FAST_CONVERTERS` constant of `convert_slow_tokenizer.py`."
+        )
+
+
+def add_new_model_like_command_factory(args: Namespace):
+    return AddNewModelLikeCommand(config_file=args.config_file, path_to_repo=args.path_to_repo)
+
+
+class AddNewModelLikeCommand(BaseTransformersCLICommand):
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        add_new_model_like_parser = parser.add_parser("add-new-model-like")
+        add_new_model_like_parser.add_argument(
+            "--config_file", type=str, help="A file with all the information for this model creation."
+        )
+        add_new_model_like_parser.add_argument(
+            "--path_to_repo", type=str, help="When not using an editable install, the path to the Transformers repo."
+        )
+        add_new_model_like_parser.set_defaults(func=add_new_model_like_command_factory)
+
+    def __init__(self, config_file=None, path_to_repo=None, *args):
+        if config_file is not None:
+            with open(config_file, "r", encoding="utf-8") as f:
+                config = json.load(f)
+            self.old_model_type = config["old_model_type"]
+            self.model_patterns = ModelPatterns(**config["new_model_patterns"])
+            self.add_copied_from = config.get("add_copied_from", True)
+            self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
+        else:
+            self.old_model_type, self.model_patterns, self.add_copied_from, self.frameworks = get_user_input()
+
+        self.path_to_repo = path_to_repo
+
+    def run(self):
+        if self.path_to_repo is not None:
+            # Adapt constants
+            global TRANSFORMERS_PATH
+            global REPO_PATH
+
+            REPO_PATH = Path(self.path_to_repo)
+            TRANSFORMERS_PATH = REPO_PATH / "src" / "transformers"
+
+        create_new_model_like(
+            model_type=self.old_model_type,
+            new_model_patterns=self.model_patterns,
+            add_copied_from=self.add_copied_from,
+            frameworks=self.frameworks,
+        )
+
+
+def get_user_field(
+    question: str,
+    default_value: Optional[str] = None,
+    is_valid_answer: Optional[Callable] = None,
+    convert_to: Optional[Callable] = None,
+    fallback_message: Optional[str] = None,
+) -> Any:
+    """
+    A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
+    answer.
+
+    Args:
+        question (`str`): The question to ask the user.
+        default_value (`str`, *optional*): A potential default value that will be used when the answer is empty.
+        is_valid_answer (`Callable`, *optional*):
+            If set, the question will be asked until this function returns `True` on the provided answer.
+        convert_to (`Callable`, *optional*):
+            If set, the answer will be passed to this function. If this function raises an error on the provided
+            answer, the question will be asked again.
+        fallback_message (`str`, *optional*):
+            A message that will be displayed each time the question is asked again to the user.
+
+    Returns:
+        `Any`: The answer provided by the user (or the default), passed through the potential conversion function.
+    """
+    if not question.endswith(" "):
+        question = question + " "
+    if default_value is not None:
+        question = f"{question} [{default_value}] "
+
+    valid_answer = False
+    while not valid_answer:
+        answer = input(question)
+        if default_value is not None and len(answer) == 0:
+            answer = default_value
+        if is_valid_answer is not None:
+            valid_answer = is_valid_answer(answer)
+        elif convert_to is not None:
+            try:
+                answer = convert_to(answer)
+                valid_answer = True
+            except Exception:
+                valid_answer = False
+        else:
+            valid_answer = True
+
+        if not valid_answer:
+            print(fallback_message)
+
+    return answer
+
+
+def convert_to_bool(x: str) -> bool:
+    """
+    Converts a string to a bool.
+    """
+    if x.lower() in ["1", "y", "yes", "true"]:
+        return True
+    if x.lower() in ["0", "n", "no", "false"]:
+        return False
+    raise ValueError(f"{x} is not a value that can be converted to a bool.")
+
+
+def get_user_input():
+    """
+    Ask the user for the necessary inputs to add the new model.
+    """
+    model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())
+
+    # Get old model type
+    valid_model_type = False
+    while not valid_model_type:
+        old_model_type = input("What is the model you would like to duplicate? 
") + if old_model_type in model_types: + valid_model_type = True + else: + print(f"{old_model_type} is not a valid model type.") + near_choices = difflib.get_close_matches(old_model_type, model_types) + if len(near_choices) >= 1: + if len(near_choices) > 1: + near_choices = " or ".join(near_choices) + print(f"Did you mean {near_choices}?") + + old_model_info = retrieve_info_for_model(old_model_type) + old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class + old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class + old_processor_class = old_model_info["model_patterns"].processor_class + old_frameworks = old_model_info["frameworks"] + + model_name = get_user_field("What is the name for your new model?") + default_patterns = ModelPatterns(model_name, model_name) + + model_type = get_user_field( + "What identifier would you like to use for the model type of this model?", + default_value=default_patterns.model_type, + ) + model_lower_cased = get_user_field( + "What name would you like to use for the module of this model?", + default_value=default_patterns.model_lower_cased, + ) + model_camel_cased = get_user_field( + "What prefix (camel-cased) would you like to use for the model classes of this model?", + default_value=default_patterns.model_camel_cased, + ) + model_upper_cased = get_user_field( + "What prefix (upper-cased) would you like to use for the constants relative to this model?", + default_value=default_patterns.model_upper_cased, + ) + config_class = get_user_field( + "What will be the name of the config class for this model?", default_value=f"{model_camel_cased}Config" + ) + checkpoint = get_user_field("Please give a checkpoint identifier (on the model Hub) for this new model.") + + old_processing_classes = [ + c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None + ] + old_processing_classes = ", ".join(old_processing_classes) + keep_processing = get_user_field( + f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes})?", + convert_to=convert_to_bool, + fallback_message="Please answer yes/no, y/n, true/false or 1/0.", + ) + if keep_processing: + feature_extractor_class = old_feature_extractor_class + processor_class = old_processor_class + tokenizer_class = old_tokenizer_class + else: + if old_tokenizer_class is not None: + tokenizer_class = get_user_field( + "What will be the name of the tokenizer class for this model?", + default_value=f"{model_camel_cased}Tokenizer", + ) + else: + tokenizer_class = None + if old_feature_extractor_class is not None: + feature_extractor_class = get_user_field( + "What will be the name of the feature extractor class for this model?", + default_value=f"{model_camel_cased}FeatureExtractor", + ) + else: + feature_extractor_class = None + if old_processor_class is not None: + processor_class = get_user_field( + "What will be the name of the processor class for this model?", + default_value=f"{model_camel_cased}Processor", + ) + else: + processor_class = None + + model_patterns = ModelPatterns( + model_name, + checkpoint, + model_type=model_type, + model_lower_cased=model_lower_cased, + model_camel_cased=model_camel_cased, + model_upper_cased=model_upper_cased, + config_class=config_class, + tokenizer_class=tokenizer_class, + feature_extractor_class=feature_extractor_class, + processor_class=processor_class, + ) + + add_copied_from = get_user_field( + "Should we add # Copied from statements when creating the new modeling 
file?", + convert_to=convert_to_bool, + default_value="yes", + fallback_message="Please answer yes/no, y/n, true/false or 1/0.", + ) + + all_frameworks = get_user_field( + f"Should we add a version of your new model in all the frameworks implemented by {old_model_type} ({old_frameworks})?", + convert_to=convert_to_bool, + default_value="yes", + fallback_message="Please answer yes/no, y/n, true/false or 1/0.", + ) + if all_frameworks: + frameworks = None + else: + frameworks = get_user_field( + "Please enter the list of framworks you want (pt, tf, flax) separated by spaces", + is_valid_answer=lambda x: all(p in ["pt", "tf", "flax"] for p in x.split(" ")), + ) + frameworks = list(set(frameworks.split(" "))) + + return (old_model_type, model_patterns, add_copied_from, frameworks) diff --git a/src/transformers/commands/transformers_cli.py b/src/transformers/commands/transformers_cli.py index d63f6bc9c6e..66c81bb9d09 100644 --- a/src/transformers/commands/transformers_cli.py +++ b/src/transformers/commands/transformers_cli.py @@ -16,6 +16,7 @@ from argparse import ArgumentParser from .add_new_model import AddNewModelCommand +from .add_new_model_like import AddNewModelLikeCommand from .convert import ConvertCommand from .download import DownloadCommand from .env import EnvironmentCommand @@ -37,6 +38,7 @@ def main(): ServeCommand.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser) AddNewModelCommand.register_subcommand(commands_parser) + AddNewModelLikeCommand.register_subcommand(commands_parser) LfsCommands.register_subcommand(commands_parser) # Let's go diff --git a/tests/fixtures/add_distilbert_like_config.json b/tests/fixtures/add_distilbert_like_config.json new file mode 100644 index 00000000000..812d2a635dd --- /dev/null +++ b/tests/fixtures/add_distilbert_like_config.json @@ -0,0 +1,19 @@ +{ + "add_copied_from": true, + "old_model_type": "distilbert", + "new_model_patterns": { + "model_name": "BERT New", + "checkpoint": "huggingface/bert-new-base", + "model_type": "bert-new", + "model_lower_cased": "bert_new", + "model_camel_cased": "BertNew", + "model_upper_cased": "BERT_NEW", + "config_class": "BertNewConfig", + "tokenizer_class": "DistilBertTokenizer" + }, + "frameworks": [ + "pt", + "tf", + "flax" + ] +} \ No newline at end of file diff --git a/tests/test_add_new_model_like.py b/tests/test_add_new_model_like.py new file mode 100644 index 00000000000..d1134bad521 --- /dev/null +++ b/tests/test_add_new_model_like.py @@ -0,0 +1,1342 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import re +import tempfile +import unittest +from pathlib import Path + +import transformers +from transformers.commands.add_new_model_like import ( + ModelPatterns, + _re_class_func, + add_content_to_file, + add_content_to_text, + clean_frameworks_in_init, + duplicate_doc_file, + duplicate_module, + filter_framework_files, + find_base_model_checkpoint, + get_model_files, + get_module_from_file, + parse_module_content, + replace_model_patterns, + retrieve_info_for_model, + retrieve_model_classes, + simplify_replacements, +) +from transformers.testing_utils import require_flax, require_tf, require_torch + + +BERT_MODEL_FILES = { + "src/transformers/models/bert/__init__.py", + "src/transformers/models/bert/configuration_bert.py", + "src/transformers/models/bert/tokenization_bert.py", + "src/transformers/models/bert/tokenization_bert_fast.py", + "src/transformers/models/bert/modeling_bert.py", + "src/transformers/models/bert/modeling_flax_bert.py", + "src/transformers/models/bert/modeling_tf_bert.py", + "src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py", + "src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py", + "src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py", +} + +VIT_MODEL_FILES = { + "src/transformers/models/vit/__init__.py", + "src/transformers/models/vit/configuration_vit.py", + "src/transformers/models/vit/convert_dino_to_pytorch.py", + "src/transformers/models/vit/convert_vit_timm_to_pytorch.py", + "src/transformers/models/vit/feature_extraction_vit.py", + "src/transformers/models/vit/modeling_vit.py", + "src/transformers/models/vit/modeling_tf_vit.py", + "src/transformers/models/vit/modeling_flax_vit.py", +} + +WAV2VEC2_MODEL_FILES = { + "src/transformers/models/wav2vec2/__init__.py", + "src/transformers/models/wav2vec2/configuration_wav2vec2.py", + "src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py", + "src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py", + "src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py", + "src/transformers/models/wav2vec2/modeling_wav2vec2.py", + "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py", + "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py", + "src/transformers/models/wav2vec2/processing_wav2vec2.py", + "src/transformers/models/wav2vec2/tokenization_wav2vec2.py", +} + +REPO_PATH = Path(transformers.__path__[0]).parent.parent + + +@require_torch +@require_tf +@require_flax +class TestAddNewModelLike(unittest.TestCase): + def init_file(self, file_name, content): + with open(file_name, "w", encoding="utf-8") as f: + f.write(content) + + def check_result(self, file_name, expected_result): + with open(file_name, "r", encoding="utf-8") as f: + self.assertEqual(f.read(), expected_result) + + def test_re_class_func(self): + self.assertEqual(_re_class_func.search("def my_function(x, y):").groups()[0], "my_function") + self.assertEqual(_re_class_func.search("class MyClass:").groups()[0], "MyClass") + self.assertEqual(_re_class_func.search("class MyClass(SuperClass):").groups()[0], "MyClass") + + def test_model_patterns_defaults(self): + model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base") + + self.assertEqual(model_patterns.model_type, "gpt-new-new") + self.assertEqual(model_patterns.model_lower_cased, "gpt_new_new") + self.assertEqual(model_patterns.model_camel_cased, "GPTNewNew") + 
self.assertEqual(model_patterns.model_upper_cased, "GPT_NEW_NEW")
+        self.assertEqual(model_patterns.config_class, "GPTNewNewConfig")
+        self.assertIsNone(model_patterns.tokenizer_class)
+        self.assertIsNone(model_patterns.feature_extractor_class)
+        self.assertIsNone(model_patterns.processor_class)
+
+    def test_parse_module_content(self):
+        test_code = """SOME_CONSTANT = a constant
+
+CONSTANT_DEFINED_ON_SEVERAL_LINES = [
+    first_item,
+    second_item
+]
+
+def function(args):
+    some code
+
+# Copied from transformers.some_module
+class SomeClass:
+    some code
+"""
+
+        expected_parts = [
+            "SOME_CONSTANT = a constant\n",
+            "CONSTANT_DEFINED_ON_SEVERAL_LINES = [\n    first_item,\n    second_item\n]",
+            "",
+            "def function(args):\n    some code\n",
+            "# Copied from transformers.some_module\nclass SomeClass:\n    some code\n",
+        ]
+        self.assertEqual(parse_module_content(test_code), expected_parts)
+
+    def test_add_content_to_text(self):
+        test_text = """all_configs = {
+    "gpt": "GPTConfig",
+    "bert": "BertConfig",
+    "t5": "T5Config",
+}"""
+
+        expected = """all_configs = {
+    "gpt": "GPTConfig",
+    "gpt2": "GPT2Config",
+    "bert": "BertConfig",
+    "t5": "T5Config",
+}"""
+        line = '    "gpt2": "GPT2Config",'
+
+        self.assertEqual(add_content_to_text(test_text, line, add_before="bert"), expected)
+        self.assertEqual(add_content_to_text(test_text, line, add_before="bert", exact_match=True), test_text)
+        self.assertEqual(
+            add_content_to_text(test_text, line, add_before='    "bert": "BertConfig",', exact_match=True), expected
+        )
+        self.assertEqual(add_content_to_text(test_text, line, add_before=re.compile(r'^\s*"bert":')), expected)
+
+        self.assertEqual(add_content_to_text(test_text, line, add_after="gpt"), expected)
+        self.assertEqual(add_content_to_text(test_text, line, add_after="gpt", exact_match=True), test_text)
+        self.assertEqual(
+            add_content_to_text(test_text, line, add_after='    "gpt": "GPTConfig",', exact_match=True), expected
+        )
+        self.assertEqual(add_content_to_text(test_text, line, add_after=re.compile(r'^\s*"gpt":')), expected)
+
+    def test_add_content_to_file(self):
+        test_text = """all_configs = {
+    "gpt": "GPTConfig",
+    "bert": "BertConfig",
+    "t5": "T5Config",
+}"""
+
+        expected = """all_configs = {
+    "gpt": "GPTConfig",
+    "gpt2": "GPT2Config",
+    "bert": "BertConfig",
+    "t5": "T5Config",
+}"""
+        line = '    "gpt2": "GPT2Config",'
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            file_name = os.path.join(tmp_dir, "code.py")
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_before="bert")
+            self.check_result(file_name, expected)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_before="bert", exact_match=True)
+            self.check_result(file_name, test_text)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_before='    "bert": "BertConfig",', exact_match=True)
+            self.check_result(file_name, expected)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_before=re.compile(r'^\s*"bert":'))
+            self.check_result(file_name, expected)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_after="gpt")
+            self.check_result(file_name, expected)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_after="gpt", exact_match=True)
+            self.check_result(file_name, test_text)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_after='    "gpt": "GPTConfig",', exact_match=True)
+            self.check_result(file_name, expected)
+
+            self.init_file(file_name, test_text)
+            add_content_to_file(file_name, line, add_after=re.compile(r'^\s*"gpt":'))
+            self.check_result(file_name, expected)
+
+    def test_simplify_replacements(self):
+        self.assertEqual(simplify_replacements([("Bert", "NewBert")]), [("Bert", "NewBert")])
+        self.assertEqual(
+            simplify_replacements([("Bert", "NewBert"), ("bert", "new-bert")]),
+            [("Bert", "NewBert"), ("bert", "new-bert")],
+        )
+        self.assertEqual(
+            simplify_replacements([("BertConfig", "NewBertConfig"), ("Bert", "NewBert"), ("bert", "new-bert")]),
+            [("Bert", "NewBert"), ("bert", "new-bert")],
+        )
+
+    def test_replace_model_patterns(self):
+        bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+        new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
+        bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BertConfig
+    load_tf_weights = load_tf_weights_in_bert
+    base_model_prefix = "bert"
+    is_parallelizable = True
+    supports_gradient_checkpointing = True
+    model_type = "bert"
+
+BERT_CONSTANT = "value"
+'''
+        bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = NewBertConfig
+    load_tf_weights = load_tf_weights_in_new_bert
+    base_model_prefix = "new_bert"
+    is_parallelizable = True
+    supports_gradient_checkpointing = True
+    model_type = "new-bert"
+
+NEW_BERT_CONSTANT = "value"
+'''
+
+        bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
+        self.assertEqual(bert_converted, bert_expected)
+        # Replacements are empty here since bert has been replaced by bert_new in some instances and bert-new
+        # in others.
+        self.assertEqual(replacements, "")
+
+        # If we remove the model type, we will get replacements
+        bert_test = bert_test.replace('    model_type = "bert"\n', "")
+        bert_expected = bert_expected.replace('    model_type = "new-bert"\n', "")
+        bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
+        self.assertEqual(bert_converted, bert_expected)
+        self.assertEqual(replacements, "BERT->NEW_BERT,Bert->NewBert,bert->new_bert")
+
+        gpt_model_patterns = ModelPatterns("GPT2", "gpt2")
+        new_gpt_model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
+        gpt_test = '''class GPT2PreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = GPT2Config
+    load_tf_weights = load_tf_weights_in_gpt2
+    base_model_prefix = "transformer"
+    is_parallelizable = True
+    supports_gradient_checkpointing = True
+
+GPT2_CONSTANT = "value"
+'''
+
+        gpt_expected = '''class GPTNewNewPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+ """ + + config_class = GPTNewNewConfig + load_tf_weights = load_tf_weights_in_gpt_new_new + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + +GPT_NEW_NEW_CONSTANT = "value" +''' + + gpt_converted, replacements = replace_model_patterns(gpt_test, gpt_model_patterns, new_gpt_model_patterns) + self.assertEqual(gpt_converted, gpt_expected) + # Replacements are empty here since GPT2 as been replaced by GPTNewNew in some instances and GPT_NEW_NEW + # in others. + self.assertEqual(replacements, "") + + roberta_model_patterns = ModelPatterns("RoBERTa", "roberta-base", model_camel_cased="Roberta") + new_roberta_model_patterns = ModelPatterns( + "RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew" + ) + roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta +class RobertaModel(RobertaPreTrainedModel): + """ The base RoBERTa model. """ + checkpoint = roberta-base + base_model_prefix = "roberta" + ''' + roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew +class RobertaNewModel(RobertaNewPreTrainedModel): + """ The base RoBERTa-New model. """ + checkpoint = huggingface/roberta-new-base + base_model_prefix = "roberta_new" + ''' + roberta_converted, replacements = replace_model_patterns( + roberta_test, roberta_model_patterns, new_roberta_model_patterns + ) + self.assertEqual(roberta_converted, roberta_expected) + + def test_get_module_from_file(self): + self.assertEqual( + get_module_from_file("/git/transformers/src/transformers/models/bert/modeling_tf_bert.py"), + "transformers.models.bert.modeling_tf_bert", + ) + self.assertEqual( + get_module_from_file("/transformers/models/gpt2/modeling_gpt2.py"), + "transformers.models.gpt2.modeling_gpt2", + ) + with self.assertRaises(ValueError): + get_module_from_file("/models/gpt2/modeling_gpt2.py") + + def test_duplicate_module(self): + bert_model_patterns = ModelPatterns("Bert", "bert-base-cased") + new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base") + bert_test = '''class TFBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + is_parallelizable = True + supports_gradient_checkpointing = True + +BERT_CONSTANT = "value" +''' + bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = NewBertConfig + load_tf_weights = load_tf_weights_in_new_bert + base_model_prefix = "new_bert" + is_parallelizable = True + supports_gradient_checkpointing = True + +NEW_BERT_CONSTANT = "value" +''' + bert_expected_with_copied_from = ( + "# Copied from transformers.bert_module.TFBertPreTrainedModel with Bert->NewBert,bert->new_bert\n" + + bert_expected + ) + with tempfile.TemporaryDirectory() as tmp_dir: + work_dir = os.path.join(tmp_dir, "transformers") + os.makedirs(work_dir) + file_name = os.path.join(work_dir, "bert_module.py") + dest_file_name = os.path.join(work_dir, "new_bert_module.py") + + self.init_file(file_name, bert_test) + duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns) + self.check_result(dest_file_name, bert_expected_with_copied_from) + + self.init_file(file_name, bert_test) + duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False) + self.check_result(dest_file_name, bert_expected) + + def test_duplicate_module_with_copied_from(self): + bert_model_patterns = ModelPatterns("Bert", "bert-base-cased") + new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base") + bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert +class TFBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + is_parallelizable = True + supports_gradient_checkpointing = True + +BERT_CONSTANT = "value" +''' + bert_expected = '''# Copied from transformers.models.xxx.XxxModel with Xxx->NewBert +class TFNewBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = NewBertConfig + load_tf_weights = load_tf_weights_in_new_bert + base_model_prefix = "new_bert" + is_parallelizable = True + supports_gradient_checkpointing = True + +NEW_BERT_CONSTANT = "value" +''' + with tempfile.TemporaryDirectory() as tmp_dir: + work_dir = os.path.join(tmp_dir, "transformers") + os.makedirs(work_dir) + file_name = os.path.join(work_dir, "bert_module.py") + dest_file_name = os.path.join(work_dir, "new_bert_module.py") + + self.init_file(file_name, bert_test) + duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns) + # There should not be a new Copied from statement, the old one should be adapated. 
+ self.check_result(dest_file_name, bert_expected) + + self.init_file(file_name, bert_test) + duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False) + self.check_result(dest_file_name, bert_expected) + + def test_filter_framework_files(self): + files = ["modeling_tf_bert.py", "modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"] + self.assertEqual(filter_framework_files(files), files) + self.assertEqual(set(filter_framework_files(files, ["pt", "tf", "flax"])), set(files)) + + self.assertEqual(set(filter_framework_files(files, ["pt"])), {"modeling_bert.py", "configuration_bert.py"}) + self.assertEqual(set(filter_framework_files(files, ["tf"])), {"modeling_tf_bert.py", "configuration_bert.py"}) + self.assertEqual( + set(filter_framework_files(files, ["flax"])), {"modeling_flax_bert.py", "configuration_bert.py"} + ) + + self.assertEqual( + set(filter_framework_files(files, ["pt", "tf"])), + {"modeling_tf_bert.py", "modeling_bert.py", "configuration_bert.py"}, + ) + self.assertEqual( + set(filter_framework_files(files, ["tf", "flax"])), + {"modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"}, + ) + self.assertEqual( + set(filter_framework_files(files, ["pt", "flax"])), + {"modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"}, + ) + + def test_get_model_files(self): + # BERT + bert_files = get_model_files("bert") + + doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]} + self.assertEqual(model_files, BERT_MODEL_FILES) + + self.assertEqual(bert_files["module_name"], "bert") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]} + bert_test_files = { + "tests/test_tokenization_bert.py", + "tests/test_modeling_bert.py", + "tests/test_modeling_tf_bert.py", + "tests/test_modeling_flax_bert.py", + } + self.assertEqual(test_files, bert_test_files) + + # VIT + vit_files = get_model_files("vit") + doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]} + self.assertEqual(model_files, VIT_MODEL_FILES) + + self.assertEqual(vit_files["module_name"], "vit") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]} + vit_test_files = { + "tests/test_feature_extraction_vit.py", + "tests/test_modeling_vit.py", + "tests/test_modeling_tf_vit.py", + "tests/test_modeling_flax_vit.py", + } + self.assertEqual(test_files, vit_test_files) + + # Wav2Vec2 + wav2vec2_files = get_model_files("wav2vec2") + doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]} + self.assertEqual(model_files, WAV2VEC2_MODEL_FILES) + + self.assertEqual(wav2vec2_files["module_name"], "wav2vec2") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]} + wav2vec2_test_files = { + "tests/test_feature_extraction_wav2vec2.py", + "tests/test_modeling_wav2vec2.py", + "tests/test_modeling_tf_wav2vec2.py", + "tests/test_modeling_flax_wav2vec2.py", + "tests/test_processor_wav2vec2.py", + "tests/test_tokenization_wav2vec2.py", + } + self.assertEqual(test_files, 
wav2vec2_test_files) + + def test_get_model_files_only_pt(self): + # BERT + bert_files = get_model_files("bert", frameworks=["pt"]) + + doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]} + bert_model_files = BERT_MODEL_FILES - { + "src/transformers/models/bert/modeling_tf_bert.py", + "src/transformers/models/bert/modeling_flax_bert.py", + } + self.assertEqual(model_files, bert_model_files) + + self.assertEqual(bert_files["module_name"], "bert") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]} + bert_test_files = { + "tests/test_tokenization_bert.py", + "tests/test_modeling_bert.py", + } + self.assertEqual(test_files, bert_test_files) + + # VIT + vit_files = get_model_files("vit", frameworks=["pt"]) + doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]} + vit_model_files = VIT_MODEL_FILES - { + "src/transformers/models/vit/modeling_tf_vit.py", + "src/transformers/models/vit/modeling_flax_vit.py", + } + self.assertEqual(model_files, vit_model_files) + + self.assertEqual(vit_files["module_name"], "vit") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]} + vit_test_files = { + "tests/test_feature_extraction_vit.py", + "tests/test_modeling_vit.py", + } + self.assertEqual(test_files, vit_test_files) + + # Wav2Vec2 + wav2vec2_files = get_model_files("wav2vec2", frameworks=["pt"]) + doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]} + wav2vec2_model_files = WAV2VEC2_MODEL_FILES - { + "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py", + "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py", + } + self.assertEqual(model_files, wav2vec2_model_files) + + self.assertEqual(wav2vec2_files["module_name"], "wav2vec2") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]} + wav2vec2_test_files = { + "tests/test_feature_extraction_wav2vec2.py", + "tests/test_modeling_wav2vec2.py", + "tests/test_processor_wav2vec2.py", + "tests/test_tokenization_wav2vec2.py", + } + self.assertEqual(test_files, wav2vec2_test_files) + + def test_get_model_files_tf_and_flax(self): + # BERT + bert_files = get_model_files("bert", frameworks=["tf", "flax"]) + + doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]} + bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"} + self.assertEqual(model_files, bert_model_files) + + self.assertEqual(bert_files["module_name"], "bert") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]} + bert_test_files = { + "tests/test_tokenization_bert.py", + "tests/test_modeling_tf_bert.py", + "tests/test_modeling_flax_bert.py", + } + self.assertEqual(test_files, bert_test_files) + + # VIT + vit_files = get_model_files("vit", frameworks=["tf", "flax"]) + doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH)) + 
self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]} + vit_model_files = VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"} + self.assertEqual(model_files, vit_model_files) + + self.assertEqual(vit_files["module_name"], "vit") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]} + vit_test_files = { + "tests/test_feature_extraction_vit.py", + "tests/test_modeling_tf_vit.py", + "tests/test_modeling_flax_vit.py", + } + self.assertEqual(test_files, vit_test_files) + + # Wav2Vec2 + wav2vec2_files = get_model_files("wav2vec2", frameworks=["tf", "flax"]) + doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx") + + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]} + wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"} + self.assertEqual(model_files, wav2vec2_model_files) + + self.assertEqual(wav2vec2_files["module_name"], "wav2vec2") + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]} + wav2vec2_test_files = { + "tests/test_feature_extraction_wav2vec2.py", + "tests/test_modeling_tf_wav2vec2.py", + "tests/test_modeling_flax_wav2vec2.py", + "tests/test_processor_wav2vec2.py", + "tests/test_tokenization_wav2vec2.py", + } + self.assertEqual(test_files, wav2vec2_test_files) + + def test_find_base_model_checkpoint(self): + self.assertEqual(find_base_model_checkpoint("bert"), "bert-base-uncased") + self.assertEqual(find_base_model_checkpoint("gpt2"), "gpt2") + + def test_retrieve_model_classes(self): + gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2").items()} + expected_gpt_classes = { + "pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"}, + "tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"}, + "flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"}, + } + self.assertEqual(gpt_classes, expected_gpt_classes) + + del expected_gpt_classes["flax"] + gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2", frameworks=["pt", "tf"]).items()} + self.assertEqual(gpt_classes, expected_gpt_classes) + + del expected_gpt_classes["pt"] + gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2", frameworks=["tf"]).items()} + self.assertEqual(gpt_classes, expected_gpt_classes) + + def test_retrieve_info_for_model_with_bert(self): + bert_info = retrieve_info_for_model("bert") + bert_classes = [ + "BertForTokenClassification", + "BertForQuestionAnswering", + "BertForNextSentencePrediction", + "BertForSequenceClassification", + "BertForMaskedLM", + "BertForMultipleChoice", + "BertModel", + "BertForPreTraining", + "BertLMHeadModel", + ] + expected_model_classes = { + "pt": set(bert_classes), + "tf": {f"TF{m}" for m in bert_classes}, + "flax": {f"Flax{m}" for m in bert_classes[:-1]}, + } + + self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf", "flax"}) + model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()} + self.assertEqual(model_classes, expected_model_classes) + + all_bert_files = bert_info["model_files"] + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]} + self.assertEqual(model_files, BERT_MODEL_FILES) + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in 
all_bert_files["test_files"]} + bert_test_files = { + "tests/test_tokenization_bert.py", + "tests/test_modeling_bert.py", + "tests/test_modeling_tf_bert.py", + "tests/test_modeling_flax_bert.py", + } + self.assertEqual(test_files, bert_test_files) + + doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx") + + self.assertEqual(all_bert_files["module_name"], "bert") + + bert_model_patterns = bert_info["model_patterns"] + self.assertEqual(bert_model_patterns.model_name, "BERT") + self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased") + self.assertEqual(bert_model_patterns.model_type, "bert") + self.assertEqual(bert_model_patterns.model_lower_cased, "bert") + self.assertEqual(bert_model_patterns.model_camel_cased, "Bert") + self.assertEqual(bert_model_patterns.model_upper_cased, "BERT") + self.assertEqual(bert_model_patterns.config_class, "BertConfig") + self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer") + self.assertIsNone(bert_model_patterns.feature_extractor_class) + self.assertIsNone(bert_model_patterns.processor_class) + + def test_retrieve_info_for_model_pt_tf_with_bert(self): + bert_info = retrieve_info_for_model("bert", frameworks=["pt", "tf"]) + bert_classes = [ + "BertForTokenClassification", + "BertForQuestionAnswering", + "BertForNextSentencePrediction", + "BertForSequenceClassification", + "BertForMaskedLM", + "BertForMultipleChoice", + "BertModel", + "BertForPreTraining", + "BertLMHeadModel", + ] + expected_model_classes = {"pt": set(bert_classes), "tf": {f"TF{m}" for m in bert_classes}} + + self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf"}) + model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()} + self.assertEqual(model_classes, expected_model_classes) + + all_bert_files = bert_info["model_files"] + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]} + bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_flax_bert.py"} + self.assertEqual(model_files, bert_model_files) + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]} + bert_test_files = { + "tests/test_tokenization_bert.py", + "tests/test_modeling_bert.py", + "tests/test_modeling_tf_bert.py", + } + self.assertEqual(test_files, bert_test_files) + + doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx") + + self.assertEqual(all_bert_files["module_name"], "bert") + + bert_model_patterns = bert_info["model_patterns"] + self.assertEqual(bert_model_patterns.model_name, "BERT") + self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased") + self.assertEqual(bert_model_patterns.model_type, "bert") + self.assertEqual(bert_model_patterns.model_lower_cased, "bert") + self.assertEqual(bert_model_patterns.model_camel_cased, "Bert") + self.assertEqual(bert_model_patterns.model_upper_cased, "BERT") + self.assertEqual(bert_model_patterns.config_class, "BertConfig") + self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer") + self.assertIsNone(bert_model_patterns.feature_extractor_class) + self.assertIsNone(bert_model_patterns.processor_class) + + def test_retrieve_info_for_model_with_vit(self): + vit_info = retrieve_info_for_model("vit") + vit_classes = ["ViTForImageClassification", "ViTModel"] + expected_model_classes = { + "pt": set(vit_classes), + "tf": {f"TF{m}" for m in vit_classes}, + 
"flax": {f"Flax{m}" for m in vit_classes}, + } + + self.assertEqual(set(vit_info["frameworks"]), {"pt", "tf", "flax"}) + model_classes = {k: set(v) for k, v in vit_info["model_classes"].items()} + self.assertEqual(model_classes, expected_model_classes) + + all_vit_files = vit_info["model_files"] + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["model_files"]} + self.assertEqual(model_files, VIT_MODEL_FILES) + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["test_files"]} + vit_test_files = { + "tests/test_feature_extraction_vit.py", + "tests/test_modeling_vit.py", + "tests/test_modeling_tf_vit.py", + "tests/test_modeling_flax_vit.py", + } + self.assertEqual(test_files, vit_test_files) + + doc_file = str(Path(all_vit_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx") + + self.assertEqual(all_vit_files["module_name"], "vit") + + vit_model_patterns = vit_info["model_patterns"] + self.assertEqual(vit_model_patterns.model_name, "ViT") + self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224") + self.assertEqual(vit_model_patterns.model_type, "vit") + self.assertEqual(vit_model_patterns.model_lower_cased, "vit") + self.assertEqual(vit_model_patterns.model_camel_cased, "ViT") + self.assertEqual(vit_model_patterns.model_upper_cased, "VIT") + self.assertEqual(vit_model_patterns.config_class, "ViTConfig") + self.assertEqual(vit_model_patterns.feature_extractor_class, "ViTFeatureExtractor") + self.assertIsNone(vit_model_patterns.tokenizer_class) + self.assertIsNone(vit_model_patterns.processor_class) + + def test_retrieve_info_for_model_with_wav2vec2(self): + wav2vec2_info = retrieve_info_for_model("wav2vec2") + wav2vec2_classes = [ + "Wav2Vec2Model", + "Wav2Vec2ForPreTraining", + "Wav2Vec2ForAudioFrameClassification", + "Wav2Vec2ForCTC", + "Wav2Vec2ForMaskedLM", + "Wav2Vec2ForSequenceClassification", + "Wav2Vec2ForXVector", + ] + expected_model_classes = { + "pt": set(wav2vec2_classes), + "tf": {f"TF{m}" for m in wav2vec2_classes[:1]}, + "flax": {f"Flax{m}" for m in wav2vec2_classes[:2]}, + } + + self.assertEqual(set(wav2vec2_info["frameworks"]), {"pt", "tf", "flax"}) + model_classes = {k: set(v) for k, v in wav2vec2_info["model_classes"].items()} + self.assertEqual(model_classes, expected_model_classes) + + all_wav2vec2_files = wav2vec2_info["model_files"] + model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["model_files"]} + self.assertEqual(model_files, WAV2VEC2_MODEL_FILES) + + test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["test_files"]} + wav2vec2_test_files = { + "tests/test_feature_extraction_wav2vec2.py", + "tests/test_modeling_wav2vec2.py", + "tests/test_modeling_tf_wav2vec2.py", + "tests/test_modeling_flax_wav2vec2.py", + "tests/test_processor_wav2vec2.py", + "tests/test_tokenization_wav2vec2.py", + } + self.assertEqual(test_files, wav2vec2_test_files) + + doc_file = str(Path(all_wav2vec2_files["doc_file"]).relative_to(REPO_PATH)) + self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx") + + self.assertEqual(all_wav2vec2_files["module_name"], "wav2vec2") + + wav2vec2_model_patterns = wav2vec2_info["model_patterns"] + self.assertEqual(wav2vec2_model_patterns.model_name, "Wav2Vec2") + self.assertEqual(wav2vec2_model_patterns.checkpoint, "facebook/wav2vec2-base-960h") + self.assertEqual(wav2vec2_model_patterns.model_type, "wav2vec2") + self.assertEqual(wav2vec2_model_patterns.model_lower_cased, 
"wav2vec2") + self.assertEqual(wav2vec2_model_patterns.model_camel_cased, "Wav2Vec2") + self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV_2_VEC_2") + self.assertEqual(wav2vec2_model_patterns.config_class, "Wav2Vec2Config") + self.assertEqual(wav2vec2_model_patterns.feature_extractor_class, "Wav2Vec2FeatureExtractor") + self.assertEqual(wav2vec2_model_patterns.processor_class, "Wav2Vec2Processor") + self.assertEqual(wav2vec2_model_patterns.tokenizer_class, "Wav2Vec2CTCTokenizer") + + def test_clean_frameworks_in_init_with_gpt(self): + test_init = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available + +_import_structure = { + "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"], + "tokenization_gpt2": ["GPT2Tokenizer"], +} + +if is_tokenizers_available(): + _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"] + +if is_torch_available(): + _import_structure["modeling_gpt2"] = ["GPT2Model"] + +if is_tf_available(): + _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"] + +if is_flax_available(): + _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"] + +if TYPE_CHECKING: + from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig + from .tokenization_gpt2 import GPT2Tokenizer + + if is_tokenizers_available(): + from .tokenization_gpt2_fast import GPT2TokenizerFast + + if is_torch_available(): + from .modeling_gpt2 import GPT2Model + + if is_tf_available(): + from .modeling_tf_gpt2 import TFGPT2Model + + if is_flax_available(): + from .modeling_flax_gpt2 import FlaxGPT2Model + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + init_no_tokenizer = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available + +_import_structure = { + "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"], +} + +if is_torch_available(): + _import_structure["modeling_gpt2"] = ["GPT2Model"] + +if is_tf_available(): + _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"] + +if is_flax_available(): + _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"] + +if TYPE_CHECKING: + from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig + + if is_torch_available(): + from .modeling_gpt2 import GPT2Model + + if is_tf_available(): + from .modeling_tf_gpt2 import TFGPT2Model + + if is_flax_available(): + from .modeling_flax_gpt2 import FlaxGPT2Model + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + init_pt_only = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available + +_import_structure = { + "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"], + "tokenization_gpt2": ["GPT2Tokenizer"], +} + +if is_tokenizers_available(): + _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"] + +if is_torch_available(): + _import_structure["modeling_gpt2"] = ["GPT2Model"] + +if TYPE_CHECKING: + from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig + from .tokenization_gpt2 import GPT2Tokenizer + + if is_tokenizers_available(): + from .tokenization_gpt2_fast import 
GPT2TokenizerFast + + if is_torch_available(): + from .modeling_gpt2 import GPT2Model + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + init_pt_only_no_tokenizer = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available + +_import_structure = { + "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"], +} + +if is_torch_available(): + _import_structure["modeling_gpt2"] = ["GPT2Model"] + +if TYPE_CHECKING: + from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig + + if is_torch_available(): + from .modeling_gpt2 import GPT2Model + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + with tempfile.TemporaryDirectory() as tmp_dir: + file_name = os.path.join(tmp_dir, "__init__.py") + + self.init_file(file_name, test_init) + clean_frameworks_in_init(file_name, keep_processing=False) + self.check_result(file_name, init_no_tokenizer) + + self.init_file(file_name, test_init) + clean_frameworks_in_init(file_name, frameworks=["pt"]) + self.check_result(file_name, init_pt_only) + + self.init_file(file_name, test_init) + clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False) + self.check_result(file_name, init_pt_only_no_tokenizer) + + def test_clean_frameworks_in_init_with_vit(self): + test_init = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available + +_import_structure = { + "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"] + +if is_torch_available(): + _import_structure["modeling_vit"] = ["ViTModel"] + +if is_tf_available(): + _import_structure["modeling_tf_vit"] = ["TFViTModel"] + +if is_flax_available(): + _import_structure["modeling_flax_vit"] = ["FlaxViTModel"] + +if TYPE_CHECKING: + from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig + + if is_vision_available(): + from .feature_extraction_vit import ViTFeatureExtractor + + if is_torch_available(): + from .modeling_vit import ViTModel + + if is_tf_available(): + from .modeling_tf_vit import ViTModel + + if is_flax_available(): + from .modeling_flax_vit import ViTModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + init_no_feature_extractor = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available + +_import_structure = { + "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], +} + +if is_torch_available(): + _import_structure["modeling_vit"] = ["ViTModel"] + +if is_tf_available(): + _import_structure["modeling_tf_vit"] = ["TFViTModel"] + +if is_flax_available(): + _import_structure["modeling_flax_vit"] = ["FlaxViTModel"] + +if TYPE_CHECKING: + from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig + + if is_torch_available(): + from .modeling_vit import ViTModel + + if is_tf_available(): + from .modeling_tf_vit import ViTModel + + if is_flax_available(): + from .modeling_flax_vit import ViTModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + 
init_pt_only = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available, is_vision_available + +_import_structure = { + "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"] + +if is_torch_available(): + _import_structure["modeling_vit"] = ["ViTModel"] + +if TYPE_CHECKING: + from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig + + if is_vision_available(): + from .feature_extraction_vit import ViTFeatureExtractor + + if is_torch_available(): + from .modeling_vit import ViTModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + init_pt_only_no_feature_extractor = """ +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available + +_import_structure = { + "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], +} + +if is_torch_available(): + _import_structure["modeling_vit"] = ["ViTModel"] + +if TYPE_CHECKING: + from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig + + if is_torch_available(): + from .modeling_vit import ViTModel + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) +""" + + with tempfile.TemporaryDirectory() as tmp_dir: + file_name = os.path.join(tmp_dir, "__init__.py") + + self.init_file(file_name, test_init) + clean_frameworks_in_init(file_name, keep_processing=False) + self.check_result(file_name, init_no_feature_extractor) + + self.init_file(file_name, test_init) + clean_frameworks_in_init(file_name, frameworks=["pt"]) + self.check_result(file_name, init_pt_only) + + self.init_file(file_name, test_init) + clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False) + self.check_result(file_name, init_pt_only_no_feature_extractor) + + def test_duplicate_doc_file(self): + test_doc = """ +# GPT2 + +## Overview + +Overview of the model. + +## GPT2Config + +[[autodoc]] GPT2Config + +## GPT2Tokenizer + +[[autodoc]] GPT2Tokenizer + - save_vocabulary + +## GPT2TokenizerFast + +[[autodoc]] GPT2TokenizerFast + +## GPT2 specific outputs + +[[autodoc]] models.gpt2.modeling_gpt2.GPT2DoubleHeadsModelOutput + +[[autodoc]] models.gpt2.modeling_tf_gpt2.TFGPT2DoubleHeadsModelOutput + +## GPT2Model + +[[autodoc]] GPT2Model + - forward + +## TFGPT2Model + +[[autodoc]] TFGPT2Model + - call + +## FlaxGPT2Model + +[[autodoc]] FlaxGPT2Model + - __call__ + +""" + test_new_doc = """ +# GPT-New New + +## Overview + +The GPT-New New model was proposed in [() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](). +The original code can be found [here](). 
+
+
+## GPTNewNewConfig
+
+[[autodoc]] GPTNewNewConfig
+
+## GPTNewNewTokenizer
+
+[[autodoc]] GPTNewNewTokenizer
+    - save_vocabulary
+
+## GPTNewNewTokenizerFast
+
+[[autodoc]] GPTNewNewTokenizerFast
+
+## GPTNewNew specific outputs
+
+[[autodoc]] models.gpt_new_new.modeling_gpt_new_new.GPTNewNewDoubleHeadsModelOutput
+
+[[autodoc]] models.gpt_new_new.modeling_tf_gpt_new_new.TFGPTNewNewDoubleHeadsModelOutput
+
+## GPTNewNewModel
+
+[[autodoc]] GPTNewNewModel
+    - forward
+
+## TFGPTNewNewModel
+
+[[autodoc]] TFGPTNewNewModel
+    - call
+
+## FlaxGPTNewNewModel
+
+[[autodoc]] FlaxGPTNewNewModel
+    - __call__
+
+"""
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            doc_file = os.path.join(tmp_dir, "gpt2.mdx")
+            new_doc_file = os.path.join(tmp_dir, "gpt-new-new.mdx")
+
+            gpt2_model_patterns = ModelPatterns("GPT2", "gpt2", tokenizer_class="GPT2Tokenizer")
+            new_model_patterns = ModelPatterns(
+                "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPTNewNewTokenizer"
+            )
+
+            self.init_file(doc_file, test_doc)
+            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
+            self.check_result(new_doc_file, test_new_doc)
+
+            test_new_doc_pt_only = test_new_doc.replace(
+                """
+## TFGPTNewNewModel
+
+[[autodoc]] TFGPTNewNewModel
+    - call
+
+## FlaxGPTNewNewModel
+
+[[autodoc]] FlaxGPTNewNewModel
+    - __call__
+
+""",
+                "",
+            )
+            self.init_file(doc_file, test_doc)
+            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
+            self.check_result(new_doc_file, test_new_doc_pt_only)
+
+            test_new_doc_no_tok = test_new_doc.replace(
+                """
+## GPTNewNewTokenizer
+
+[[autodoc]] GPTNewNewTokenizer
+    - save_vocabulary
+
+## GPTNewNewTokenizerFast
+
+[[autodoc]] GPTNewNewTokenizerFast
+""",
+                "",
+            )
+            new_model_patterns = ModelPatterns(
+                "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPT2Tokenizer"
+            )
+            self.init_file(doc_file, test_doc)
+            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
+            self.check_result(new_doc_file, test_new_doc_no_tok)
+
+            test_new_doc_pt_only_no_tok = test_new_doc_no_tok.replace(
+                """
+## TFGPTNewNewModel
+
+[[autodoc]] TFGPTNewNewModel
+    - call
+
+## FlaxGPTNewNewModel
+
+[[autodoc]] FlaxGPTNewNewModel
+    - __call__
+
+""",
+                "",
+            )
+            self.init_file(doc_file, test_doc)
+            duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
+            self.check_result(new_doc_file, test_new_doc_pt_only_no_tok)
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 358d638a327..1036981a628 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -243,6 +243,7 @@ def create_reverse_dependency_map():
 # Any module file that has a test name which can't be inferred automatically from its name should go here. A better
 # approach is to (re-)name the test file accordingly, and second best to add the correspondence map here.
 SPECIAL_MODULE_TO_TEST_MAP = {
+    "commands/add_new_model_like.py": "test_add_new_model_like.py",
     "configuration_utils.py": "test_configuration_common.py",
     "convert_graph_to_onnx.py": "test_onnx.py",
     "data/data_collator.py": "test_data_collator.py",