diff --git a/.github/workflows/add-model-like.yml b/.github/workflows/add-model-like.yml
new file mode 100644
index 00000000000..89f6e840aae
--- /dev/null
+++ b/.github/workflows/add-model-like.yml
@@ -0,0 +1,60 @@
+name: Add model like runner
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ paths:
+ - "src/**"
+ - "tests/**"
+ - ".github/**"
+ types: [opened, synchronize, reopened]
+
+jobs:
+ run_tests_templates:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Loading cache.
+ uses: actions/cache@v2
+ id: cache
+ with:
+ path: ~/.cache/pip
+ key: v1-tests_model_like
+ restore-keys: |
+ v1-tests_model_like-${{ hashFiles('setup.py') }}
+ v1-tests_model_like
+
+ - name: Install dependencies
+ run: |
+ pip install --upgrade pip!=21.3
+ sudo apt -y update && sudo apt install -y libsndfile1-dev
+ pip install .[dev]
+
+ - name: Create model files
+ run: |
+ transformers-cli add-new-model-like --config_file tests/fixtures/add_distilbert_like_config.json --path_to_repo .
+ make style
+ make fix-copies
+
+      - name: Run all PyTorch modeling tests
+ run: |
+ python -m pytest -n 2 --dist=loadfile -s --make-reports=tests_new_models tests/test_modeling_bert_new.py
+
+ - name: Run style changes
+ run: |
+ git fetch origin master:master
+ make style && make quality && make repo-consistency
+
+ - name: Failure short reports
+ if: ${{ always() }}
+ run: cat reports/tests_new_models_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v2
+ with:
+ name: run_all_tests_new_models_test_reports
+ path: reports
diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
new file mode 100644
index 00000000000..e443a235c42
--- /dev/null
+++ b/src/transformers/commands/add_new_model_like.py
@@ -0,0 +1,1478 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import difflib
+import json
+import os
+import re
+from argparse import ArgumentParser, Namespace
+from dataclasses import dataclass
+from itertools import chain
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
+
+import transformers.models.auto as auto_module
+from transformers.models.auto.configuration_auto import model_type_to_module_name
+
+from ..utils import logging
+from . import BaseTransformersCLICommand
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+TRANSFORMERS_PATH = Path(__file__).parent.parent
+REPO_PATH = TRANSFORMERS_PATH.parent.parent
+
+
+@dataclass
+class ModelPatterns:
+ """
+ Holds the basic information about a new model for the add-new-model-like command.
+
+ Args:
+ model_name (`str`): The model name.
+ checkpoint (`str`): The checkpoint to use for doc examples.
+ model_type (`str`, *optional*):
+ The model type, the identifier used internally in the library like `bert` or `xlm-roberta`. Will default to
+ `model_name` lowercased with spaces replaced with minuses (-).
+ model_lower_cased (`str`, *optional*):
+ The lowercased version of the model name, to use for the module name or function names. Will default to
+ `model_name` lowercased with spaces and minuses replaced with underscores.
+ model_camel_cased (`str`, *optional*):
+ The camel-cased version of the model name, to use for the class names. Will default to `model_name`
+            camel-cased (with spaces and minuses both considered as word separators).
+ model_upper_cased (`str`, *optional*):
+ The uppercased version of the model name, to use for the constant names. Will default to `model_name`
+ uppercased with spaces and minuses replaced with underscores.
+ config_class (`str`, *optional*):
+            The config class associated with this model. Will default to `"{model_camel_cased}Config"`.
+ tokenizer_class (`str`, *optional*):
+ The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
+ feature_extractor_class (`str`, *optional*):
+ The feature extractor class associated with this model (leave to `None` for models that don't use a feature
+ extractor).
+ processor_class (`str`, *optional*):
+ The processor class associated with this model (leave to `None` for models that don't use a processor).
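+
+    Example (the values below are purely illustrative and do not point to a real checkpoint):
+
+    ```python
+    patterns = ModelPatterns("GPT New", checkpoint="huggingface/gpt-new-base")
+    patterns.model_type  # "gpt-new"
+    patterns.model_lower_cased  # "gpt_new"
+    patterns.model_camel_cased  # "GPTNew"
+    patterns.model_upper_cased  # "GPT_NEW"
+    patterns.config_class  # "GPTNewConfig"
+    ```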
+ """
+
+ model_name: str
+ checkpoint: str
+ model_type: Optional[str] = None
+ model_lower_cased: Optional[str] = None
+ model_camel_cased: Optional[str] = None
+ model_upper_cased: Optional[str] = None
+ config_class: Optional[str] = None
+ tokenizer_class: Optional[str] = None
+ feature_extractor_class: Optional[str] = None
+ processor_class: Optional[str] = None
+
+ def __post_init__(self):
+ if self.model_type is None:
+ self.model_type = self.model_name.lower().replace(" ", "-")
+ if self.model_lower_cased is None:
+ self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_")
+ if self.model_camel_cased is None:
+ # Split the model name on - and space
+ words = self.model_name.split(" ")
+ words = list(chain(*[w.split("-") for w in words]))
+ # Make sure each word is capitalized
+ words = [w[0].upper() + w[1:] for w in words]
+ self.model_camel_cased = "".join(words)
+ if self.model_upper_cased is None:
+ self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_")
+ if self.config_class is None:
+ self.config_class = f"{self.model_camel_cased}Config"
+
+
+ATTRIBUTE_TO_PLACEHOLDER = {
+ "config_class": "[CONFIG_CLASS]",
+ "tokenizer_class": "[TOKENIZER_CLASS]",
+ "feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
+ "processor_class": "[PROCESSOR_CLASS]",
+ "checkpoint": "[CHECKPOINT]",
+ "model_type": "[MODEL_TYPE]",
+ "model_upper_cased": "[MODEL_UPPER_CASED]",
+ "model_camel_cased": "[MODEL_CAMELCASED]",
+ "model_lower_cased": "[MODEL_LOWER_CASED]",
+ "model_name": "[MODEL_NAME]",
+}
+
+
+def is_empty_line(line: str) -> bool:
+ """
+ Determines whether a line is empty or not.
+ """
+ return len(line) == 0 or line.isspace()
+
+
+def find_indent(line: str) -> int:
+ """
+ Returns the number of spaces that start a line indent.
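+
+    Example (illustrative):
+
+    ```python
+    find_indent("    return 1")  # 4
+    find_indent("def f():")  # 0
+    ```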
+ """
+ search = re.search("^(\s*)(?:\S|$)", line)
+ if search is None:
+ return 0
+ return len(search.groups()[0])
+
+
+def parse_module_content(content: str) -> List[str]:
+ """
+    Parse the content of a module into the list of objects it defines.
+
+ Args:
+ content (`str`): The content to parse
+
+ Returns:
+ `List[str]`: The list of objects defined in the module.
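+
+    Example (a minimal sketch of how a module is split into top-level objects):
+
+    ```python
+    parse_module_content("import os\n\n\nclass Foo:\n    pass\n")
+    # ['import os\n\n', 'class Foo:\n    pass\n']
+    ```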
+ """
+ objects = []
+ current_object = []
+ lines = content.split("\n")
+ # Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
+ end_markers = [")", "]", "}", '"""']
+
+ for line in lines:
+ # End of an object
+ is_valid_object = len(current_object) > 0
+ if is_valid_object and len(current_object) == 1:
+ is_valid_object = not current_object[0].startswith("# Copied from")
+ if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
+ # Closing parts should be included in current object
+ if line in end_markers:
+ current_object.append(line)
+ objects.append("\n".join(current_object))
+ current_object = []
+ else:
+ objects.append("\n".join(current_object))
+ current_object = [line]
+ else:
+ current_object.append(line)
+
+ # Add last object
+ if len(current_object) > 0:
+ objects.append("\n".join(current_object))
+
+ return objects
+
+
+def add_content_to_text(
+ text: str,
+ content: str,
+ add_after: Optional[Union[str, Pattern]] = None,
+ add_before: Optional[Union[str, Pattern]] = None,
+ exact_match: bool = False,
+) -> str:
+ """
+ A utility to add some content inside a given text.
+
+ Args:
+ text (`str`): The text in which we want to insert some content.
+ content (`str`): The content to add.
+ add_after (`str` or `Pattern`):
+ The pattern to test on a line of `text`, the new content is added after the first instance matching it.
+ add_before (`str` or `Pattern`):
+ The pattern to test on a line of `text`, the new content is added before the first instance matching it.
+ exact_match (`bool`, *optional*, defaults to `False`):
+            If `True`, a line is only considered a match with `add_after` or `add_before` if it is exactly equal to
+            the pattern; otherwise, the pattern just needs to be contained in the line.
+
+
+
+    The arguments `add_after` and `add_before` are mutually exclusive, and exactly one needs to be provided.
+
+
+
+ Returns:
+ `str`: The text with the new content added if a match was found.
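+
+    Example (illustrative):
+
+    ```python
+    text = "line 1\nline 2\nline 3"
+    add_content_to_text(text, "new line", add_after="line 2")
+    # 'line 1\nline 2\nnew line\nline 3'
+    ```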
+ """
+ if add_after is None and add_before is None:
+ raise ValueError("You need to pass either `add_after` or `add_before`")
+ if add_after is not None and add_before is not None:
+        raise ValueError("You can't pass both `add_after` and `add_before`")
+ pattern = add_after if add_before is None else add_before
+
+ def this_is_the_line(line):
+ if isinstance(pattern, Pattern):
+ return pattern.search(line) is not None
+ elif exact_match:
+ return pattern == line
+ else:
+ return pattern in line
+
+ new_lines = []
+ for line in text.split("\n"):
+ if this_is_the_line(line):
+ if add_before is not None:
+ new_lines.append(content)
+ new_lines.append(line)
+ if add_after is not None:
+ new_lines.append(content)
+ else:
+ new_lines.append(line)
+
+ return "\n".join(new_lines)
+
+
+def add_content_to_file(
+ file_name: Union[str, os.PathLike],
+ content: str,
+ add_after: Optional[Union[str, Pattern]] = None,
+ add_before: Optional[Union[str, Pattern]] = None,
+ exact_match: bool = False,
+):
+ """
+ A utility to add some content inside a given file.
+
+ Args:
+ file_name (`str` or `os.PathLike`): The name of the file in which we want to insert some content.
+ content (`str`): The content to add.
+ add_after (`str` or `Pattern`):
+ The pattern to test on a line of `text`, the new content is added after the first instance matching it.
+ add_before (`str` or `Pattern`):
+ The pattern to test on a line of `text`, the new content is added before the first instance matching it.
+ exact_match (`bool`, *optional*, defaults to `False`):
+            If `True`, a line is only considered a match with `add_after` or `add_before` if it is exactly equal to
+            the pattern; otherwise, the pattern just needs to be contained in the line.
+
+
+
+    The arguments `add_after` and `add_before` are mutually exclusive, and exactly one needs to be provided.
+
+
+ """
+ with open(file_name, "r", encoding="utf-8") as f:
+ old_content = f.read()
+
+ new_content = add_content_to_text(
+ old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
+ )
+
+ with open(file_name, "w", encoding="utf-8") as f:
+ f.write(new_content)
+
+
+def replace_model_patterns(
+ text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
+) -> Tuple[str, str]:
+ """
+ Replace all patterns present in a given text.
+
+ Args:
+ text (`str`): The text to treat.
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+
+ Returns:
+        `Tuple[str, str]`: A tuple with the treated text and the replacements actually done in it.
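+
+    Example (a minimal sketch; in practice the patterns come from `retrieve_info_for_model` and the user input):
+
+    ```python
+    old = ModelPatterns("BERT", checkpoint="bert-base-cased", model_camel_cased="Bert")
+    new = ModelPatterns("BERT New", checkpoint="huggingface/bert-new-base", model_camel_cased="BertNew")
+    replace_model_patterns("class BertForMaskedLM(BertPreTrainedModel):", old, new)
+    # ('class BertNewForMaskedLM(BertNewPreTrainedModel):', 'Bert->BertNew')
+    ```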
+ """
+ # The order is crucially important as we will check and replace in that order. For instance the config probably
+    # contains the camel-cased name, but will be treated before it.
+ attributes_to_check = ["config_class"]
+ # Add relevant preprocessing classes
+ for attr in ["tokenizer_class", "feature_extractor_class", "processor_class"]:
+ if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
+ attributes_to_check.append(attr)
+
+ # Special cases for checkpoint and model_type
+ if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
+ attributes_to_check.append("checkpoint")
+ if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
+ attributes_to_check.append("model_type")
+ else:
+ text = re.sub(
+ fr'(\s*)model_type = "{old_model_patterns.model_type}"',
+ r'\1model_type = "[MODEL_TYPE]"',
+ text,
+ )
+
+ # Special case when the model camel cased and upper cased names are the same for the old model (like for GPT2) but
+ # not the new one. We can't just do a replace in all the text and will need a special regex
+ if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
+ old_model_value = old_model_patterns.model_upper_cased
+ if re.search(fr"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
+ text = re.sub(fr"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
+ else:
+ attributes_to_check.append("model_upper_cased")
+
+ attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])
+
+ # Now let's replace every other attribute by their placeholder
+ for attr in attributes_to_check:
+ text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])
+
+    # Finally we can replace the placeholders by the new values.
+ replacements = []
+ for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
+ if placeholder in text:
+ replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
+ text = text.replace(placeholder, getattr(new_model_patterns, attr))
+
+ # If we have two inconsistent replacements, we don't return anything (ex: GPT2->GPT_NEW and GPT2->GPTNew)
+ old_replacement_values = [old for old, new in replacements]
+ if len(set(old_replacement_values)) != len(old_replacement_values):
+ return text, ""
+
+ replacements = simplify_replacements(replacements)
+ replacements = [f"{old}->{new}" for old, new in replacements]
+ return text, ",".join(replacements)
+
+
+def simplify_replacements(replacements):
+ """
+ Simplify a list of replacement patterns to make sure there are no needless ones.
+
+ For instance in the sequence "Bert->BertNew, BertConfig->BertNewConfig, bert->bert_new", the replacement
+    "BertConfig->BertNewConfig" is implied by "Bert->BertNew" so it is not needed.
+
+ Args:
+ replacements (`List[Tuple[str, str]]`): List of patterns (old, new)
+
+ Returns:
+ `List[Tuple[str, str]]`: The list of patterns simplified.
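+
+    Example (illustrative):
+
+    ```python
+    simplify_replacements([("Bert", "BertNew"), ("BertConfig", "BertNewConfig"), ("bert", "bert_new")])
+    # [("Bert", "BertNew"), ("bert", "bert_new")]
+    ```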
+ """
+ if len(replacements) <= 1:
+ # Nothing to simplify
+ return replacements
+
+ # Next let's sort replacements by length as a replacement can only "imply" another replacement if it's shorter.
+ replacements.sort(key=lambda x: len(x[0]))
+
+ idx = 0
+ while idx < len(replacements):
+ old, new = replacements[idx]
+ # Loop through all replacements after
+ j = idx + 1
+ while j < len(replacements):
+ old_2, new_2 = replacements[j]
+ # If the replacement is implied by the current one, we can drop it.
+ if old_2.replace(old, new) == new_2:
+ replacements.pop(j)
+ else:
+ j += 1
+ idx += 1
+
+ return replacements
+
+
+def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
+ """
+ Returns the module name corresponding to a module file.
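+
+    Example (the repo location below is only illustrative):
+
+    ```python
+    get_module_from_file("/git/transformers/src/transformers/models/bert/modeling_bert.py")
+    # "transformers.models.bert.modeling_bert"
+    ```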
+ """
+ full_module_path = Path(module_file).absolute()
+ module_parts = full_module_path.with_suffix("").parts
+
+ # Find the first part named transformers, starting from the end.
+ idx = len(module_parts) - 1
+ while idx >= 0 and module_parts[idx] != "transformers":
+ idx -= 1
+ if idx < 0:
+ raise ValueError(f"{module_file} is not a transformers module.")
+
+ return ".".join(module_parts[idx:])
+
+
+SPECIAL_PATTERNS = {
+ "_CHECKPOINT_FOR_DOC =": "checkpoint",
+ "_CONFIG_FOR_DOC =": "config_class",
+ "_TOKENIZER_FOR_DOC =": "tokenizer_class",
+ "_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
+ "_PROCESSOR_FOR_DOC =": "processor_class",
+}
+
+
+_re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)
+
+
+def duplicate_module(
+ module_file: Union[str, os.PathLike],
+ old_model_patterns: ModelPatterns,
+ new_model_patterns: ModelPatterns,
+ dest_file: Optional[str] = None,
+ add_copied_from: bool = True,
+):
+ """
+    Create a new module from an existing one, adapting all function and class names from the old patterns to the new
+    ones.
+
+ Args:
+ module_file (`str` or `os.PathLike`): Path to the module to duplicate.
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+ dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
+ add_copied_from (`bool`, *optional*, defaults to `True`):
+ Whether or not to add `# Copied from` statements in the duplicated module.
+ """
+ if dest_file is None:
+ dest_file = str(module_file).replace(
+ old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
+ )
+
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ objects = parse_module_content(content)
+
+ # Loop and treat all objects
+ new_objects = []
+ for obj in objects:
+ # Special cases
+ if "PRETRAINED_CONFIG_ARCHIVE_MAP = {" in obj:
+ # docstyle-ignore
+ obj = (
+ f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP = "
+ + "{"
+ + f"""
+ "{new_model_patterns.checkpoint}": "https://huggingface.co/{new_model_patterns.checkpoint}/resolve/main/config.json",
+"""
+ + "}\n"
+ )
+ new_objects.append(obj)
+ continue
+ elif "PRETRAINED_MODEL_ARCHIVE_LIST = [" in obj:
+ if obj.startswith("TF_"):
+ prefix = "TF_"
+ elif obj.startswith("FLAX_"):
+ prefix = "FLAX_"
+ else:
+ prefix = ""
+ # docstyle-ignore
+ obj = f"""{prefix}{new_model_patterns.model_upper_cased}_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "{new_model_patterns.checkpoint}",
+ # See all {new_model_patterns.model_name} models at https://huggingface.co/models?filter={new_model_patterns.model_type}
+]
+"""
+ new_objects.append(obj)
+ continue
+
+ special_pattern = False
+ for pattern, attr in SPECIAL_PATTERNS.items():
+ if pattern in obj:
+ obj = obj.replace(getattr(old_model_patterns, attr), getattr(new_model_patterns, attr))
+ new_objects.append(obj)
+ special_pattern = True
+ break
+
+ if special_pattern:
+ continue
+
+        # Regular classes and functions
+ old_obj = obj
+ obj, replacement = replace_model_patterns(obj, old_model_patterns, new_model_patterns)
+ has_copied_from = re.search("^#\s+Copied from", obj, flags=re.MULTILINE) is not None
+ if add_copied_from and not has_copied_from and _re_class_func.search(obj) is not None and len(replacement) > 0:
+ # Copied from statement must be added just before the class/function definition, which may not be the
+ # first line because of decorators.
+ module_name = get_module_from_file(module_file)
+ old_object_name = _re_class_func.search(old_obj).groups()[0]
+ obj = add_content_to_text(
+ obj, f"# Copied from {module_name}.{old_object_name} with {replacement}", add_before=_re_class_func
+ )
+        # In all cases, we remove the indented "Copied from" statements on methods.
+ obj = re.sub("\n[ ]+# Copied from [^\n]*\n", "\n", obj)
+
+ new_objects.append(obj)
+
+ with open(dest_file, "w", encoding="utf-8") as f:
+        f.write("\n".join(new_objects))
+
+
+def filter_framework_files(
+ files: List[Union[str, os.PathLike]], frameworks: Optional[List[str]] = None
+) -> List[Union[str, os.PathLike]]:
+ """
+ Filter a list of files to only keep the ones corresponding to a list of frameworks.
+
+ Args:
+ files (`List[Union[str, os.PathLike]]`): The list of files to filter.
+ frameworks (`List[str]`, *optional*): The list of allowed frameworks.
+
+ Returns:
+ `List[Union[str, os.PathLike]]`: The list of filtered files.
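+
+    Example (file names are illustrative):
+
+    ```python
+    files = ["modeling_bert.py", "modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
+    filter_framework_files(files, frameworks=["pt"])
+    # ['modeling_bert.py', 'configuration_bert.py']
+    ```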
+ """
+ if frameworks is None:
+ return files
+
+ framework_to_file = {}
+ others = []
+ for f in files:
+ parts = Path(f).name.split("_")
+ if "modeling" not in parts:
+ others.append(f)
+ continue
+ if "tf" in parts:
+ framework_to_file["tf"] = f
+ elif "flax" in parts:
+ framework_to_file["flax"] = f
+ else:
+ framework_to_file["pt"] = f
+
+ return [framework_to_file[f] for f in frameworks] + others
+
+
+def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, Union[Path, List[Path]]]:
+ """
+    Retrieves all the files associated with a model.
+
+ Args:
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
+ frameworks (`List[str]`, *optional*):
+ If passed, will only keep the model files corresponding to the passed frameworks.
+
+ Returns:
+ `Dict[str, Union[Path, List[Path]]]`: A dictionary with the following keys:
+ - **doc_file** -- The documentation file for the model.
+ - **model_files** -- All the files in the model module.
+        - **module_name** -- The name of the module corresponding to the model type.
+        - **test_files** -- The test files for the model.
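+
+    Example (a sketch of the shape of the result, assuming an editable install of the repo):
+
+    ```python
+    files = get_model_files("bert", frameworks=["pt"])
+    files["module_name"]  # "bert"
+    files["doc_file"]  # path to docs/source/model_doc/bert.mdx
+    ```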
+ """
+ module_name = model_type_to_module_name(model_type)
+
+ model_module = TRANSFORMERS_PATH / "models" / module_name
+ model_files = list(model_module.glob("*.py"))
+ model_files = filter_framework_files(model_files, frameworks=frameworks)
+
+ doc_file = REPO_PATH / "docs" / "source" / "model_doc" / f"{model_type}.mdx"
+
+ # Basic pattern for test files
+ test_files = [
+ f"test_modeling_{module_name}.py",
+ f"test_modeling_tf_{module_name}.py",
+ f"test_modeling_flax_{module_name}.py",
+ f"test_tokenization_{module_name}.py",
+ f"test_feature_extraction_{module_name}.py",
+ f"test_processor_{module_name}.py",
+ ]
+ test_files = filter_framework_files(test_files, frameworks=frameworks)
+ # Add the test directory
+ test_files = [REPO_PATH / "tests" / f for f in test_files]
+ # Filter by existing files
+ test_files = [f for f in test_files if f.exists()]
+
+ return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
+
+
+_re_checkpoint_for_doc = re.compile("^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE)
+
+
+def find_base_model_checkpoint(
+ model_type: str, model_files: Optional[Dict[str, Union[Path, List[Path]]]] = None
+) -> str:
+ """
+ Finds the model checkpoint used in the docstrings for a given model.
+
+ Args:
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
+        model_files (`Dict[str, Union[Path, List[Path]]]`, *optional*):
+            The files associated with `model_type`. Can be passed to speed up the function, otherwise will be computed.
+
+ Returns:
+ `str`: The checkpoint used.
+ """
+ if model_files is None:
+ model_files = get_model_files(model_type)
+ module_files = model_files["model_files"]
+ for fname in module_files:
+ if "modeling" not in str(fname):
+ continue
+
+ with open(fname, "r", encoding="utf-8") as f:
+ content = f.read()
+ if _re_checkpoint_for_doc.search(content) is not None:
+ checkpoint = _re_checkpoint_for_doc.search(content).groups()[0]
+ # Remove quotes
+ checkpoint = checkpoint.replace('"', "")
+ checkpoint = checkpoint.replace("'", "")
+ return checkpoint
+
+    # TODO: Find some kind of fallback if there is no _CHECKPOINT_FOR_DOC in any of the modeling files.
+ return ""
+
+
+_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
+
+
+def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, List[str]]:
+ """
+    Retrieve the model classes associated with a given model.
+
+ Args:
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
+ frameworks (`List[str]`, *optional*):
+ The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict
+ the classes returned.
+
+ Returns:
+        `Dict[str, List[str]]`: A dictionary with one key per framework and the list of model classes associated with
+ that framework as values.
+ """
+ if frameworks is None:
+ frameworks = ["pt", "tf", "flax"]
+
+ modules = {
+ "pt": auto_module.modeling_auto,
+ "tf": auto_module.modeling_tf_auto,
+ "flax": auto_module.modeling_flax_auto,
+ }
+
+ model_classes = {}
+ for framework in frameworks:
+ new_model_classes = []
+ model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
+ for model_mapping_name in model_mappings:
+ model_mapping = getattr(modules[framework], model_mapping_name)
+ if model_type in model_mapping:
+ new_model_classes.append(model_mapping[model_type])
+
+ if len(new_model_classes) > 0:
+ # Remove duplicates
+ model_classes[framework] = list(set(new_model_classes))
+
+ return model_classes
+
+
+def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
+ """
+ Retrieves all the information from a given model_type.
+
+ Args:
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
+ frameworks (`List[str]`, *optional*):
+ If passed, will only keep the info corresponding to the passed frameworks.
+
+ Returns:
+ `Dict`: A dictionary with the following keys:
+ - **frameworks** (`List[str]`): The list of frameworks that back this model type.
+ - **model_classes** (`Dict[str, List[str]]`): The model classes implemented for that model type.
+ - **model_files** (`Dict[str, Union[Path, List[Path]]]`): The files associated with that model type.
+ - **model_patterns** (`ModelPatterns`): The various patterns for the model.
+ """
+ if model_type not in auto_module.MODEL_NAMES_MAPPING:
+ raise ValueError(f"{model_type} is not a valid model type.")
+
+ model_name = auto_module.MODEL_NAMES_MAPPING[model_type]
+ config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type]
+ archive_map = auto_module.configuration_auto.CONFIG_ARCHIVE_MAP_MAPPING_NAMES.get(model_type, None)
+ if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
+ tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type]
+ tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
+ else:
+ tokenizer_class = None
+ feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
+ processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
+
+ model_files = get_model_files(model_type, frameworks=frameworks)
+ model_camel_cased = config_class.replace("Config", "")
+
+ available_frameworks = []
+ for fname in model_files["model_files"]:
+ if "modeling_tf" in str(fname):
+ available_frameworks.append("tf")
+ elif "modeling_flax" in str(fname):
+ available_frameworks.append("flax")
+ elif "modeling" in str(fname):
+ available_frameworks.append("pt")
+
+ if frameworks is None:
+ frameworks = available_frameworks.copy()
+ else:
+ frameworks = [f for f in frameworks if f in available_frameworks]
+
+ model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
+
+ # Retrieve model upper-cased name from the constant name of the pretrained archive map.
+ if archive_map is None:
+ model_upper_cased = model_camel_cased.upper()
+ else:
+ parts = archive_map.split("_")
+ idx = 0
+ while idx < len(parts) and parts[idx] != "PRETRAINED":
+ idx += 1
+ if idx < len(parts):
+ model_upper_cased = "_".join(parts[:idx])
+ else:
+ model_upper_cased = model_camel_cased.upper()
+
+ model_patterns = ModelPatterns(
+ model_name,
+ checkpoint=find_base_model_checkpoint(model_type, model_files=model_files),
+ model_type=model_type,
+ model_camel_cased=model_camel_cased,
+ model_lower_cased=model_files["module_name"],
+ model_upper_cased=model_upper_cased,
+ config_class=config_class,
+ tokenizer_class=tokenizer_class,
+ feature_extractor_class=feature_extractor_class,
+ processor_class=processor_class,
+ )
+
+ return {
+ "frameworks": frameworks,
+ "model_classes": model_classes,
+ "model_files": model_files,
+ "model_patterns": model_patterns,
+ }
+
+
+def clean_frameworks_in_init(
+ init_file: Union[str, os.PathLike], frameworks: Optional[List[str]] = None, keep_processing: bool = True
+):
+ """
+    Removes from an init all the import lines attached to frameworks that are not in a given list, and optionally the
+    lines concerning tokenizers/feature extractors/processors.
+
+ Args:
+ init_file (`str` or `os.PathLike`): The path to the init to treat.
+ frameworks (`List[str]`, *optional*):
+            If passed, all imports guarded by a framework not in this list will be removed.
+ keep_processing (`bool`, *optional*, defaults to `True`):
+ Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
+ """
+ if frameworks is None:
+ frameworks = ["pt", "tf", "flax"]
+
+ names = {"pt": "torch"}
+ to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
+ if not keep_processing:
+ to_remove.extend(["sentencepiece", "tokenizers", "vision"])
+
+ if len(to_remove) == 0:
+ # Nothing to do
+ return
+
+ remove_pattern = "|".join(to_remove)
+ re_conditional_imports = re.compile(fr"^\s*if is_({remove_pattern})_available\(\):\s*$")
+ re_is_xxx_available = re.compile(fr"is_({remove_pattern})_available")
+
+ with open(init_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ lines = content.split("\n")
+ new_lines = []
+ idx = 0
+ while idx < len(lines):
+ # Conditional imports
+ if re_conditional_imports.search(lines[idx]) is not None:
+ idx += 1
+ while is_empty_line(lines[idx]):
+ idx += 1
+ indent = find_indent(lines[idx])
+ while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
+ idx += 1
+ # Remove the import from file_utils
+ elif re_is_xxx_available.search(lines[idx]) is not None:
+ line = lines[idx]
+ for framework in to_remove:
+ line = line.replace(f", is_{framework}_available", "")
+ line = line.replace(f"is_{framework}_available, ", "")
+ line = line.replace(f"is_{framework}_available", "")
+
+ if len(line.strip()) > 0:
+ new_lines.append(line)
+ idx += 1
+ # Otherwise we keep the line, except if it's a tokenizer import and we don't want to keep it.
+ elif keep_processing or (
+ re.search('^\s*"(tokenization|processing|feature_extraction)', lines[idx]) is None
+ and re.search("^\s*from .(tokenization|processing|feature_extraction)", lines[idx]) is None
+ ):
+ new_lines.append(lines[idx])
+ idx += 1
+ else:
+ idx += 1
+
+ with open(init_file, "w", encoding="utf-8") as f:
+ f.write("\n".join(new_lines))
+
+
+def add_model_to_main_init(
+ old_model_patterns: ModelPatterns,
+ new_model_patterns: ModelPatterns,
+ frameworks: Optional[List[str]] = None,
+ with_processing: bool = True,
+):
+ """
+ Add a model to the main init of Transformers.
+
+ Args:
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+ frameworks (`List[str]`, *optional*):
+ If specified, only the models implemented in those frameworks will be added.
+        with_processing (`bool`, *optional*, defaults to `True`):
+ Whether the tokenizer/feature extractor/processor of the model should also be added to the init or not.
+ """
+ with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
+ content = f.read()
+
+ lines = content.split("\n")
+ idx = 0
+ new_lines = []
+ framework = None
+ while idx < len(lines):
+ if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
+ framework = None
+ elif lines[idx].lstrip().startswith("if is_torch_available"):
+ framework = "pt"
+ elif lines[idx].lstrip().startswith("if is_tf_available"):
+ framework = "tf"
+ elif lines[idx].lstrip().startswith("if is_flax_available"):
+ framework = "flax"
+
+ # Skip if we are in a framework not wanted.
+ if framework is not None and frameworks is not None and framework not in frameworks:
+ new_lines.append(lines[idx])
+ idx += 1
+ elif re.search(fr'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
+ block = [lines[idx]]
+ indent = find_indent(lines[idx])
+ idx += 1
+ while find_indent(lines[idx]) > indent:
+ block.append(lines[idx])
+ idx += 1
+ if lines[idx].strip() in [")", "]", "],"]:
+ block.append(lines[idx])
+ idx += 1
+ block = "\n".join(block)
+ new_lines.append(block)
+
+ add_block = True
+ if not with_processing:
+ processing_classes = [
+ old_model_patterns.tokenizer_class,
+ old_model_patterns.feature_extractor_class,
+ old_model_patterns.processor_class,
+ ]
+ # Only keep the ones that are not None
+ processing_classes = [c for c in processing_classes if c is not None]
+ for processing_class in processing_classes:
+ block = block.replace(f' "{processing_class}",', "")
+ block = block.replace(f', "{processing_class}"', "")
+ block = block.replace(f" {processing_class},", "")
+ block = block.replace(f", {processing_class}", "")
+
+ if processing_class in block:
+ add_block = False
+ if add_block:
+ new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
+ else:
+ new_lines.append(lines[idx])
+ idx += 1
+
+ with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
+ f.write("\n".join(new_lines))
+
+
+def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
+ """
+ Add a tokenizer to the relevant mappings in the auto module.
+
+ Args:
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+ """
+ if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
+ return
+
+ with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
+ content = f.read()
+
+ lines = content.split("\n")
+ idx = 0
+ # First we get to the TOKENIZER_MAPPING_NAMES block.
+ while not lines[idx].startswith(" TOKENIZER_MAPPING_NAMES = OrderedDict("):
+ idx += 1
+ idx += 1
+
+    # That block will end when we reach this line:
+ while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
+        # Either the whole tokenizer block is defined on one line, in which case it ends with a ","
+ if lines[idx].endswith(","):
+ block = lines[idx]
+ # Otherwise it takes several lines until we get to a "),"
+ else:
+ block = []
+ while not lines[idx].startswith(" ),"):
+ block.append(lines[idx])
+ idx += 1
+ block = "\n".join(block)
+ idx += 1
+
+ # If we find the model type and tokenizer class in that block, we have the old model tokenizer block
+ if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
+ break
+
+ new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
+ new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)
+
+ new_lines = lines[:idx] + [new_block] + lines[idx:]
+ with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
+ f.write("\n".join(new_lines))
+
+
+AUTO_CLASSES_PATTERNS = {
+ "configuration_auto.py": [
+ ' ("{model_type}", "{model_name}"),',
+ ' ("{model_type}", "{config_class}"),',
+ ' ("{model_type}", "{pretrained_archive_map}"),',
+ ],
+ "feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
+ "modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
+ "modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
+ "modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
+ "processing_auto.py": [' ("{model_type}", "{processor_class}"),'],
+}
+
+
+def add_model_to_auto_classes(
+ old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: Dict[str, List[str]]
+):
+ """
+ Add a model to the relevant mappings in the auto module.
+
+ Args:
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+        model_classes (`Dict[str, List[str]]`):
+            A dictionary mapping each framework to the list of model classes implemented for that framework.
+ """
+ for filename in AUTO_CLASSES_PATTERNS:
+ # Extend patterns with all model classes if necessary
+ new_patterns = []
+ for pattern in AUTO_CLASSES_PATTERNS[filename]:
+ if re.search("any_([a-z]*)_class", pattern) is not None:
+ framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
+ if framework in model_classes:
+ new_patterns.extend(
+ [
+ pattern.replace("{" + f"any_{framework}_class" + "}", cls)
+ for cls in model_classes[framework]
+ ]
+ )
+ elif "{config_class}" in pattern:
+ new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
+ elif "{feature_extractor_class}" in pattern:
+ if (
+ old_model_patterns.feature_extractor_class is not None
+ and new_model_patterns.feature_extractor_class is not None
+ ):
+ new_patterns.append(
+ pattern.replace("{feature_extractor_class}", old_model_patterns.feature_extractor_class)
+ )
+ elif "{processor_class}" in pattern:
+ if old_model_patterns.processor_class is not None and new_model_patterns.processor_class is not None:
+ new_patterns.append(pattern.replace("{processor_class}", old_model_patterns.processor_class))
+ else:
+ new_patterns.append(pattern)
+
+ # Loop through all patterns.
+ for pattern in new_patterns:
+ full_name = TRANSFORMERS_PATH / "models" / "auto" / filename
+ old_model_line = pattern
+ new_model_line = pattern
+ for attr in ["model_type", "model_name"]:
+ old_model_line = old_model_line.replace("{" + attr + "}", getattr(old_model_patterns, attr))
+ new_model_line = new_model_line.replace("{" + attr + "}", getattr(new_model_patterns, attr))
+ if "pretrained_archive_map" in pattern:
+ old_model_line = old_model_line.replace(
+ "{pretrained_archive_map}", f"{old_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
+ )
+ new_model_line = new_model_line.replace(
+ "{pretrained_archive_map}", f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
+ )
+
+ new_model_line = new_model_line.replace(
+ old_model_patterns.model_camel_cased, new_model_patterns.model_camel_cased
+ )
+
+ add_content_to_file(full_name, new_model_line, add_after=old_model_line)
+
+ # Tokenizers require special handling
+ insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
+
+
+DOC_OVERVIEW_TEMPLATE = """## Overview
+
+The {model_name} model was proposed in []() by .
+
+
+The abstract from the paper is the following:
+
+**
+
+Tips:
+
+
+
+This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/).
+The original code can be found [here]().
+
+"""
+
+
+def duplicate_doc_file(
+ doc_file: Union[str, os.PathLike],
+ old_model_patterns: ModelPatterns,
+ new_model_patterns: ModelPatterns,
+ dest_file: Optional[Union[str, os.PathLike]] = None,
+ frameworks: Optional[List[str]] = None,
+):
+ """
+    Duplicate a documentation file and adapt it for a new model.
+
+ Args:
+        doc_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+ dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
+            Will default to a file named `{new_model_patterns.model_type}.mdx` in the same folder as `doc_file`.
+ frameworks (`List[str]`, *optional*):
+ If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
+ """
+ with open(doc_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ if frameworks is None:
+ frameworks = ["pt", "tf", "flax"]
+ if dest_file is None:
+ dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx"
+
+ # Parse the doc file in blocks. One block per section/header
+ lines = content.split("\n")
+ blocks = []
+ current_block = []
+
+ for line in lines:
+ if line.startswith("#"):
+ blocks.append("\n".join(current_block))
+ current_block = [line]
+ else:
+ current_block.append(line)
+ blocks.append("\n".join(current_block))
+
+ new_blocks = []
+ in_classes = False
+ for block in blocks:
+ # Copyright
+ if not block.startswith("#"):
+ new_blocks.append(block)
+ # Main title
+ elif re.search("^#\s+\S+", block) is not None:
+ new_blocks.append(f"# {new_model_patterns.model_name}\n")
+ # The config starts the part of the doc with the classes.
+ elif not in_classes and old_model_patterns.config_class in block.split("\n")[0]:
+ in_classes = True
+ new_blocks.append(DOC_OVERVIEW_TEMPLATE.format(model_name=new_model_patterns.model_name))
+ new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
+ new_blocks.append(new_block)
+ # In classes
+ elif in_classes:
+ in_classes = True
+ block_title = block.split("\n")[0]
+ block_class = re.search("^#+\s+(\S.*)$", block_title).groups()[0]
+ new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
+
+ if "Tokenizer" in block_class:
+ # We only add the tokenizer if necessary
+ if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
+ new_blocks.append(new_block)
+ elif "FeatureExtractor" in block_class:
+ # We only add the feature extractor if necessary
+ if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
+ new_blocks.append(new_block)
+ elif "Processor" in block_class:
+ # We only add the processor if necessary
+ if old_model_patterns.processor_class != new_model_patterns.processor_class:
+ new_blocks.append(new_block)
+ elif block_class.startswith("Flax"):
+ # We only add Flax models if in the selected frameworks
+ if "flax" in frameworks:
+ new_blocks.append(new_block)
+ elif block_class.startswith("TF"):
+ # We only add TF models if in the selected frameworks
+ if "tf" in frameworks:
+ new_blocks.append(new_block)
+ elif len(block_class.split(" ")) == 1:
+ # We only add PyTorch models if in the selected frameworks
+ if "pt" in frameworks:
+ new_blocks.append(new_block)
+ else:
+ new_blocks.append(new_block)
+
+ with open(dest_file, "w", encoding="utf-8") as f:
+ f.write("\n".join(new_blocks))
+
+
+def create_new_model_like(
+ model_type: str,
+ new_model_patterns: ModelPatterns,
+ add_copied_from: bool = True,
+ frameworks: Optional[List[str]] = None,
+):
+ """
+ Creates a new model module like a given model of the Transformers library.
+
+ Args:
+ model_type (`str`): The model type to duplicate (like "bert" or "gpt2")
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
+ add_copied_from (`bool`, *optional*, defaults to `True`):
+ Whether or not to add "Copied from" statements to all classes in the new model modeling files.
+ frameworks (`List[str]`, *optional*):
+ If passed, will limit the duplicate to the frameworks specified.
+ """
+ # Retrieve all the old model info.
+ model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
+ model_files = model_info["model_files"]
+ old_model_patterns = model_info["model_patterns"]
+ keep_old_processing = True
+ for processing_attr in ["feature_extractor_class", "processor_class", "tokenizer_class"]:
+ if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
+ keep_old_processing = False
+
+ model_classes = model_info["model_classes"]
+
+ # 1. We create the module for our new model.
+ old_module_name = model_files["module_name"]
+ module_folder = TRANSFORMERS_PATH / "models" / new_model_patterns.model_lower_cased
+ os.makedirs(module_folder, exist_ok=True)
+
+ files_to_adapt = model_files["model_files"]
+ if keep_old_processing:
+ files_to_adapt = [
+ f
+ for f in files_to_adapt
+ if "tokenization" not in str(f) and "processing" not in str(f) and "feature_extraction" not in str(f)
+ ]
+
+ os.makedirs(module_folder, exist_ok=True)
+ for module_file in files_to_adapt:
+ new_module_name = module_file.name.replace(
+ old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
+ )
+ dest_file = module_folder / new_module_name
+ duplicate_module(
+ module_file,
+ old_model_patterns,
+ new_model_patterns,
+ dest_file=dest_file,
+ add_copied_from=add_copied_from and "modeling" in new_module_name,
+ )
+
+ clean_frameworks_in_init(
+ module_folder / "__init__.py", frameworks=frameworks, keep_processing=not keep_old_processing
+ )
+
+ # 2. We add our new model to the models init and the main init
+ add_content_to_file(
+ TRANSFORMERS_PATH / "models" / "__init__.py",
+ f" {new_model_patterns.model_lower_cased},",
+ add_after=f" {old_module_name},",
+ exact_match=True,
+ )
+ add_model_to_main_init(
+ old_model_patterns, new_model_patterns, frameworks=frameworks, with_processing=not keep_old_processing
+ )
+
+ # 3. Add test files
+ files_to_adapt = model_files["test_files"]
+ if keep_old_processing:
+ files_to_adapt = [
+ f
+ for f in files_to_adapt
+ if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
+ ]
+
+ for test_file in files_to_adapt:
+ new_test_file_name = test_file.name.replace(
+ old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
+ )
+ dest_file = test_file.parent / new_test_file_name
+ duplicate_module(
+ test_file,
+ old_model_patterns,
+ new_model_patterns,
+ dest_file=dest_file,
+ add_copied_from=False,
+ )
+
+ # 4. Add model to auto classes
+ add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
+
+ # 5. Add doc file
+ doc_file = REPO_PATH / "docs" / "source" / "model_doc" / f"{old_model_patterns.model_type}.mdx"
+ duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
+
+ # 6. Warn the user for duplicate patterns
+ if old_model_patterns.model_type == old_model_patterns.checkpoint:
+ print(
+            "The model you picked has the same name for the model type and the checkpoint name "
+            f"({old_model_patterns.model_type}). As a result, it's possible that in some places where the new "
+            f"checkpoint should be, you have {new_model_patterns.model_type} instead. You should search for all "
+            f"instances of {new_model_patterns.model_type} in the new files and check they're not badly used as "
+            "checkpoints."
+ )
+ elif old_model_patterns.model_lower_cased == old_model_patterns.checkpoint:
+ print(
+            "The model you picked has the same name for the lowercased model name and the checkpoint name "
+            f"({old_model_patterns.model_lower_cased}). As a result, it's possible that in some places where the new "
+ f"checkpoint should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
+ f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
+ "used as checkpoints."
+ )
+ if (
+ old_model_patterns.model_type == old_model_patterns.model_lower_cased
+ and new_model_patterns.model_type != new_model_patterns.model_lower_cased
+ ):
+ print(
+ "The model you picked has the same name for the model type and the lowercased model name "
+            f"({old_model_patterns.model_lower_cased}). As a result, it's possible that in some places where the new "
+ f"model type should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
+ f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
+ "used as the model type."
+ )
+
+ if not keep_old_processing and old_model_patterns.tokenizer_class is not None:
+ print(
+            "The constants at the start of the newly created tokenizer file need to be manually fixed. If your new "
+            "model has a fast tokenizer, you will also need to manually add the converter in the "
+ "`SLOW_TO_FAST_CONVERTERS` constant of `convert_slow_tokenizer.py`."
+ )
+
+
+def add_new_model_like_command_factory(args: Namespace):
+ return AddNewModelLikeCommand(config_file=args.config_file, path_to_repo=args.path_to_repo)
+
+
+class AddNewModelLikeCommand(BaseTransformersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ add_new_model_like_parser = parser.add_parser("add-new-model-like")
+ add_new_model_like_parser.add_argument(
+ "--config_file", type=str, help="A file with all the information for this model creation."
+ )
+ add_new_model_like_parser.add_argument(
+ "--path_to_repo", type=str, help="When not using an editable install, the path to the Transformers repo."
+ )
+ add_new_model_like_parser.set_defaults(func=add_new_model_like_command_factory)
+
+ def __init__(self, config_file=None, path_to_repo=None, *args):
+ if config_file is not None:
+ with open(config_file, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ self.old_model_type = config["old_model_type"]
+ self.model_patterns = ModelPatterns(**config["new_model_patterns"])
+ self.add_copied_from = config.get("add_copied_from", True)
+ self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
+ else:
+ self.old_model_type, self.model_patterns, self.add_copied_from, self.frameworks = get_user_input()
+
+ self.path_to_repo = path_to_repo
+
+ def run(self):
+ if self.path_to_repo is not None:
+ # Adapt constants
+ global TRANSFORMERS_PATH
+ global REPO_PATH
+
+ REPO_PATH = Path(self.path_to_repo)
+ TRANSFORMERS_PATH = REPO_PATH / "src" / "transformers"
+
+ create_new_model_like(
+ model_type=self.old_model_type,
+ new_model_patterns=self.model_patterns,
+ add_copied_from=self.add_copied_from,
+ frameworks=self.frameworks,
+ )
+
+
+def get_user_field(
+ question: str,
+ default_value: Optional[str] = None,
+ is_valid_answer: Optional[Callable] = None,
+ convert_to: Optional[Callable] = None,
+ fallback_message: Optional[str] = None,
+) -> Any:
+ """
+ A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
+ answer.
+
+ Args:
+ question (`str`): The question to ask the user.
+ default_value (`str`, *optional*): A potential default value that will be used when the answer is empty.
+ is_valid_answer (`Callable`, *optional*):
+ If set, the question will be asked until this function returns `True` on the provided answer.
+ convert_to (`Callable`, *optional*):
+            If set, the answer will be passed to this function. If this function raises an error on the provided
+ answer, the question will be asked again.
+ fallback_message (`str`, *optional*):
+ A message that will be displayed each time the question is asked again to the user.
+
+ Returns:
+ `Any`: The answer provided by the user (or the default), passed through the potential conversion function.
+ """
+ if not question.endswith(" "):
+ question = question + " "
+ if default_value is not None:
+        question = f"{question}[{default_value}] "
+
+ valid_answer = False
+ while not valid_answer:
+ answer = input(question)
+ if default_value is not None and len(answer) == 0:
+ answer = default_value
+ if is_valid_answer is not None:
+ valid_answer = is_valid_answer(answer)
+ elif convert_to is not None:
+ try:
+ answer = convert_to(answer)
+ valid_answer = True
+ except Exception:
+ valid_answer = False
+ else:
+ valid_answer = True
+
+ if not valid_answer:
+ print(fallback_message)
+
+ return answer
+
+
+def convert_to_bool(x: str) -> bool:
+ """
+ Converts a string to a bool.
+ """
+ if x.lower() in ["1", "y", "yes", "true"]:
+ return True
+ if x.lower() in ["0", "n", "no", "false"]:
+ return False
+ raise ValueError(f"{x} is not a value that can be converted to a bool.")
+
+
+def get_user_input():
+ """
+ Ask the user for the necessary inputs to add the new model.
+ """
+ model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())
+
+ # Get old model type
+ valid_model_type = False
+ while not valid_model_type:
+ old_model_type = input("What is the model you would like to duplicate? ")
+ if old_model_type in model_types:
+ valid_model_type = True
+ else:
+ print(f"{old_model_type} is not a valid model type.")
+ near_choices = difflib.get_close_matches(old_model_type, model_types)
+            if len(near_choices) >= 1:
+                near_choices = " or ".join(near_choices)
+                print(f"Did you mean {near_choices}?")
+
+ old_model_info = retrieve_info_for_model(old_model_type)
+ old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
+ old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
+ old_processor_class = old_model_info["model_patterns"].processor_class
+ old_frameworks = old_model_info["frameworks"]
+
+ model_name = get_user_field("What is the name for your new model?")
+ default_patterns = ModelPatterns(model_name, model_name)
+
+ model_type = get_user_field(
+ "What identifier would you like to use for the model type of this model?",
+ default_value=default_patterns.model_type,
+ )
+ model_lower_cased = get_user_field(
+ "What name would you like to use for the module of this model?",
+ default_value=default_patterns.model_lower_cased,
+ )
+ model_camel_cased = get_user_field(
+ "What prefix (camel-cased) would you like to use for the model classes of this model?",
+ default_value=default_patterns.model_camel_cased,
+ )
+ model_upper_cased = get_user_field(
+ "What prefix (upper-cased) would you like to use for the constants relative to this model?",
+ default_value=default_patterns.model_upper_cased,
+ )
+ config_class = get_user_field(
+ "What will be the name of the config class for this model?", default_value=f"{model_camel_cased}Config"
+ )
+ checkpoint = get_user_field("Please give a checkpoint identifier (on the model Hub) for this new model.")
+
+ old_processing_classes = [
+ c for c in [old_feature_extractor_class, old_tokenizer_class, old_processor_class] if c is not None
+ ]
+ old_processing_classes = ", ".join(old_processing_classes)
+ keep_processing = get_user_field(
+ f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes})?",
+ convert_to=convert_to_bool,
+ fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
+ )
+ if keep_processing:
+ feature_extractor_class = old_feature_extractor_class
+ processor_class = old_processor_class
+ tokenizer_class = old_tokenizer_class
+ else:
+ if old_tokenizer_class is not None:
+ tokenizer_class = get_user_field(
+ "What will be the name of the tokenizer class for this model?",
+ default_value=f"{model_camel_cased}Tokenizer",
+ )
+ else:
+ tokenizer_class = None
+ if old_feature_extractor_class is not None:
+ feature_extractor_class = get_user_field(
+ "What will be the name of the feature extractor class for this model?",
+ default_value=f"{model_camel_cased}FeatureExtractor",
+ )
+ else:
+ feature_extractor_class = None
+ if old_processor_class is not None:
+ processor_class = get_user_field(
+ "What will be the name of the processor class for this model?",
+ default_value=f"{model_camel_cased}Processor",
+ )
+ else:
+ processor_class = None
+
+ model_patterns = ModelPatterns(
+ model_name,
+ checkpoint,
+ model_type=model_type,
+ model_lower_cased=model_lower_cased,
+ model_camel_cased=model_camel_cased,
+ model_upper_cased=model_upper_cased,
+ config_class=config_class,
+ tokenizer_class=tokenizer_class,
+ feature_extractor_class=feature_extractor_class,
+ processor_class=processor_class,
+ )
+
+ add_copied_from = get_user_field(
+ "Should we add # Copied from statements when creating the new modeling file?",
+ convert_to=convert_to_bool,
+ default_value="yes",
+ fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
+ )
+
+ all_frameworks = get_user_field(
+ f"Should we add a version of your new model in all the frameworks implemented by {old_model_type} ({old_frameworks})?",
+ convert_to=convert_to_bool,
+ default_value="yes",
+ fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
+ )
+ if all_frameworks:
+ frameworks = None
+ else:
+ frameworks = get_user_field(
+            "Please enter the list of frameworks you want (pt, tf, flax) separated by spaces",
+ is_valid_answer=lambda x: all(p in ["pt", "tf", "flax"] for p in x.split(" ")),
+ )
+ frameworks = list(set(frameworks.split(" ")))
+
+ return (old_model_type, model_patterns, add_copied_from, frameworks)
diff --git a/src/transformers/commands/transformers_cli.py b/src/transformers/commands/transformers_cli.py
index d63f6bc9c6e..66c81bb9d09 100644
--- a/src/transformers/commands/transformers_cli.py
+++ b/src/transformers/commands/transformers_cli.py
@@ -16,6 +16,7 @@
from argparse import ArgumentParser
from .add_new_model import AddNewModelCommand
+from .add_new_model_like import AddNewModelLikeCommand
from .convert import ConvertCommand
from .download import DownloadCommand
from .env import EnvironmentCommand
@@ -37,6 +38,7 @@ def main():
ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser)
AddNewModelCommand.register_subcommand(commands_parser)
+ AddNewModelLikeCommand.register_subcommand(commands_parser)
LfsCommands.register_subcommand(commands_parser)
# Let's go
diff --git a/tests/fixtures/add_distilbert_like_config.json b/tests/fixtures/add_distilbert_like_config.json
new file mode 100644
index 00000000000..812d2a635dd
--- /dev/null
+++ b/tests/fixtures/add_distilbert_like_config.json
@@ -0,0 +1,19 @@
+{
+ "add_copied_from": true,
+ "old_model_type": "distilbert",
+ "new_model_patterns": {
+ "model_name": "BERT New",
+ "checkpoint": "huggingface/bert-new-base",
+ "model_type": "bert-new",
+ "model_lower_cased": "bert_new",
+ "model_camel_cased": "BertNew",
+ "model_upper_cased": "BERT_NEW",
+ "config_class": "BertNewConfig",
+ "tokenizer_class": "DistilBertTokenizer"
+ },
+ "frameworks": [
+ "pt",
+ "tf",
+ "flax"
+ ]
+}
\ No newline at end of file
diff --git a/tests/test_add_new_model_like.py b/tests/test_add_new_model_like.py
new file mode 100644
index 00000000000..d1134bad521
--- /dev/null
+++ b/tests/test_add_new_model_like.py
@@ -0,0 +1,1342 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+import tempfile
+import unittest
+from pathlib import Path
+
+import transformers
+from transformers.commands.add_new_model_like import (
+ ModelPatterns,
+ _re_class_func,
+ add_content_to_file,
+ add_content_to_text,
+ clean_frameworks_in_init,
+ duplicate_doc_file,
+ duplicate_module,
+ filter_framework_files,
+ find_base_model_checkpoint,
+ get_model_files,
+ get_module_from_file,
+ parse_module_content,
+ replace_model_patterns,
+ retrieve_info_for_model,
+ retrieve_model_classes,
+ simplify_replacements,
+)
+from transformers.testing_utils import require_flax, require_tf, require_torch
+
+
+BERT_MODEL_FILES = {
+ "src/transformers/models/bert/__init__.py",
+ "src/transformers/models/bert/configuration_bert.py",
+ "src/transformers/models/bert/tokenization_bert.py",
+ "src/transformers/models/bert/tokenization_bert_fast.py",
+ "src/transformers/models/bert/modeling_bert.py",
+ "src/transformers/models/bert/modeling_flax_bert.py",
+ "src/transformers/models/bert/modeling_tf_bert.py",
+ "src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py",
+ "src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py",
+ "src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py",
+}
+
+VIT_MODEL_FILES = {
+ "src/transformers/models/vit/__init__.py",
+ "src/transformers/models/vit/configuration_vit.py",
+ "src/transformers/models/vit/convert_dino_to_pytorch.py",
+ "src/transformers/models/vit/convert_vit_timm_to_pytorch.py",
+ "src/transformers/models/vit/feature_extraction_vit.py",
+ "src/transformers/models/vit/modeling_vit.py",
+ "src/transformers/models/vit/modeling_tf_vit.py",
+ "src/transformers/models/vit/modeling_flax_vit.py",
+}
+
+WAV2VEC2_MODEL_FILES = {
+ "src/transformers/models/wav2vec2/__init__.py",
+ "src/transformers/models/wav2vec2/configuration_wav2vec2.py",
+ "src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py",
+ "src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py",
+ "src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py",
+ "src/transformers/models/wav2vec2/modeling_wav2vec2.py",
+ "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py",
+ "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py",
+ "src/transformers/models/wav2vec2/processing_wav2vec2.py",
+ "src/transformers/models/wav2vec2/tokenization_wav2vec2.py",
+}
+
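+# Root of the repository checkout, used to express model, test and doc files as relative paths in the assertions below.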
+REPO_PATH = Path(transformers.__path__[0]).parent.parent
+
+
+@require_torch
+@require_tf
+@require_flax
+class TestAddNewModelLike(unittest.TestCase):
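+ # Small helpers to write a temporary file and to compare a file's content with an expected result.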
+ def init_file(self, file_name, content):
+ with open(file_name, "w", encoding="utf-8") as f:
+ f.write(content)
+
+ def check_result(self, file_name, expected_result):
+ with open(file_name, "r", encoding="utf-8") as f:
+ self.assertEqual(f.read(), expected_result)
+
+ def test_re_class_func(self):
+ self.assertEqual(_re_class_func.search("def my_function(x, y):").groups()[0], "my_function")
+ self.assertEqual(_re_class_func.search("class MyClass:").groups()[0], "MyClass")
+ self.assertEqual(_re_class_func.search("class MyClass(SuperClass):").groups()[0], "MyClass")
+
+ def test_model_patterns_defaults(self):
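+ # When only the model name and checkpoint are provided, all casings and the config class name should be derived automatically.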
+ model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
+
+ self.assertEqual(model_patterns.model_type, "gpt-new-new")
+ self.assertEqual(model_patterns.model_lower_cased, "gpt_new_new")
+ self.assertEqual(model_patterns.model_camel_cased, "GPTNewNew")
+ self.assertEqual(model_patterns.model_upper_cased, "GPT_NEW_NEW")
+ self.assertEqual(model_patterns.config_class, "GPTNewNewConfig")
+ self.assertIsNone(model_patterns.tokenizer_class)
+ self.assertIsNone(model_patterns.feature_extractor_class)
+ self.assertIsNone(model_patterns.processor_class)
+
+ def test_parse_module_content(self):
+ test_code = """SOME_CONSTANT = a constant
+
+CONSTANT_DEFINED_ON_SEVERAL_LINES = [
+ first_item,
+ second_item
+]
+
+def function(args):
+ some code
+
+# Copied from transformers.some_module
+class SomeClass:
+ some code
+"""
+
+ expected_parts = [
+ "SOME_CONSTANT = a constant\n",
+ "CONSTANT_DEFINED_ON_SEVERAL_LINES = [\n first_item,\n second_item\n]",
+ "",
+ "def function(args):\n some code\n",
+ "# Copied from transformers.some_module\nclass SomeClass:\n some code\n",
+ ]
+ self.assertEqual(parse_module_content(test_code), expected_parts)
+
+ def test_add_content_to_text(self):
+ test_text = """all_configs = {
+ "gpt": "GPTConfig",
+ "bert": "BertConfig",
+ "t5": "T5Config",
+}"""
+
+ expected = """all_configs = {
+ "gpt": "GPTConfig",
+ "gpt2": "GPT2Config",
+ "bert": "BertConfig",
+ "t5": "T5Config",
+}"""
+ line = ' "gpt2": "GPT2Config",'
+
+ self.assertEqual(add_content_to_text(test_text, line, add_before="bert"), expected)
+ self.assertEqual(add_content_to_text(test_text, line, add_before="bert", exact_match=True), test_text)
+ self.assertEqual(
+ add_content_to_text(test_text, line, add_before=' "bert": "BertConfig",', exact_match=True), expected
+ )
+ self.assertEqual(add_content_to_text(test_text, line, add_before=re.compile(r'^\s*"bert":')), expected)
+
+ self.assertEqual(add_content_to_text(test_text, line, add_after="gpt"), expected)
+ self.assertEqual(add_content_to_text(test_text, line, add_after="gpt", exact_match=True), test_text)
+ self.assertEqual(
+ add_content_to_text(test_text, line, add_after=' "gpt": "GPTConfig",', exact_match=True), expected
+ )
+ self.assertEqual(add_content_to_text(test_text, line, add_after=re.compile(r'^\s*"gpt":')), expected)
+
+ def test_add_content_to_file(self):
+ test_text = """all_configs = {
+ "gpt": "GPTConfig",
+ "bert": "BertConfig",
+ "t5": "T5Config",
+}"""
+
+ expected = """all_configs = {
+ "gpt": "GPTConfig",
+ "gpt2": "GPT2Config",
+ "bert": "BertConfig",
+ "t5": "T5Config",
+}"""
+ line = ' "gpt2": "GPT2Config",'
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file_name = os.path.join(tmp_dir, "code.py")
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_before="bert")
+ self.check_result(file_name, expected)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_before="bert", exact_match=True)
+ self.check_result(file_name, test_text)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_before=' "bert": "BertConfig",', exact_match=True)
+ self.check_result(file_name, expected)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_before=re.compile(r'^\s*"bert":'))
+ self.check_result(file_name, expected)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_after="gpt")
+ self.check_result(file_name, expected)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_after="gpt", exact_match=True)
+ self.check_result(file_name, test_text)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_after=' "gpt": "GPTConfig",', exact_match=True)
+ self.check_result(file_name, expected)
+
+ self.init_file(file_name, test_text)
+ add_content_to_file(file_name, line, add_after=re.compile(r'^\s*"gpt":'))
+ self.check_result(file_name, expected)
+
+ def test_simplify_replacements(self):
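+ # Replacements implied by a more general one (e.g. BertConfig->NewBertConfig follows from Bert->NewBert) should be dropped.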
+ self.assertEqual(simplify_replacements([("Bert", "NewBert")]), [("Bert", "NewBert")])
+ self.assertEqual(
+ simplify_replacements([("Bert", "NewBert"), ("bert", "new-bert")]),
+ [("Bert", "NewBert"), ("bert", "new-bert")],
+ )
+ self.assertEqual(
+ simplify_replacements([("BertConfig", "NewBertConfig"), ("Bert", "NewBert"), ("bert", "new-bert")]),
+ [("Bert", "NewBert"), ("bert", "new-bert")],
+ )
+
+ def test_replace_model_patterns(self):
+ bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+ new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
+ bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+ model_type = "bert"
+
+BERT_CONSTANT = "value"
+'''
+ bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = NewBertConfig
+ load_tf_weights = load_tf_weights_in_new_bert
+ base_model_prefix = "new_bert"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+ model_type = "new-bert"
+
+NEW_BERT_CONSTANT = "value"
+'''
+
+ bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
+ self.assertEqual(bert_converted, bert_expected)
+ # Replacements are empty here since "bert" has been replaced by "new_bert" in some instances and by
+ # "new-bert" in others.
+ self.assertEqual(replacements, "")
+
+ # If we remove the model type, we will get replacements
+ bert_test = bert_test.replace(' model_type = "bert"\n', "")
+ bert_expected = bert_expected.replace(' model_type = "new-bert"\n', "")
+ bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
+ self.assertEqual(bert_converted, bert_expected)
+ self.assertEqual(replacements, "BERT->NEW_BERT,Bert->NewBert,bert->new_bert")
+
+ gpt_model_patterns = ModelPatterns("GPT2", "gpt2")
+ new_gpt_model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
+ gpt_test = '''class GPT2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPT2Config
+ load_tf_weights = load_tf_weights_in_gpt2
+ base_model_prefix = "transformer"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+GPT2_CONSTANT = "value"
+'''
+
+ gpt_expected = '''class GPTNewNewPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GPTNewNewConfig
+ load_tf_weights = load_tf_weights_in_gpt_new_new
+ base_model_prefix = "transformer"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+GPT_NEW_NEW_CONSTANT = "value"
+'''
+
+ gpt_converted, replacements = replace_model_patterns(gpt_test, gpt_model_patterns, new_gpt_model_patterns)
+ self.assertEqual(gpt_converted, gpt_expected)
+ # Replacements are empty here since GPT2 has been replaced by GPTNewNew in some instances and by GPT_NEW_NEW
+ # in others.
+ self.assertEqual(replacements, "")
+
+ roberta_model_patterns = ModelPatterns("RoBERTa", "roberta-base", model_camel_cased="Roberta")
+ new_roberta_model_patterns = ModelPatterns(
+ "RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew"
+ )
+ roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta
+class RobertaModel(RobertaPreTrainedModel):
+ """ The base RoBERTa model. """
+ checkpoint = roberta-base
+ base_model_prefix = "roberta"
+ '''
+ roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew
+class RobertaNewModel(RobertaNewPreTrainedModel):
+ """ The base RoBERTa-New model. """
+ checkpoint = huggingface/roberta-new-base
+ base_model_prefix = "roberta_new"
+ '''
+ roberta_converted, replacements = replace_model_patterns(
+ roberta_test, roberta_model_patterns, new_roberta_model_patterns
+ )
+ self.assertEqual(roberta_converted, roberta_expected)
+
+ def test_get_module_from_file(self):
+ self.assertEqual(
+ get_module_from_file("/git/transformers/src/transformers/models/bert/modeling_tf_bert.py"),
+ "transformers.models.bert.modeling_tf_bert",
+ )
+ self.assertEqual(
+ get_module_from_file("/transformers/models/gpt2/modeling_gpt2.py"),
+ "transformers.models.gpt2.modeling_gpt2",
+ )
+ with self.assertRaises(ValueError):
+ get_module_from_file("/models/gpt2/modeling_gpt2.py")
+
+ def test_duplicate_module(self):
+ bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+ new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
+ bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+BERT_CONSTANT = "value"
+'''
+ bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = NewBertConfig
+ load_tf_weights = load_tf_weights_in_new_bert
+ base_model_prefix = "new_bert"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+NEW_BERT_CONSTANT = "value"
+'''
+ bert_expected_with_copied_from = (
+ "# Copied from transformers.bert_module.TFBertPreTrainedModel with Bert->NewBert,bert->new_bert\n"
+ + bert_expected
+ )
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ work_dir = os.path.join(tmp_dir, "transformers")
+ os.makedirs(work_dir)
+ file_name = os.path.join(work_dir, "bert_module.py")
+ dest_file_name = os.path.join(work_dir, "new_bert_module.py")
+
+ self.init_file(file_name, bert_test)
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
+ self.check_result(dest_file_name, bert_expected_with_copied_from)
+
+ self.init_file(file_name, bert_test)
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
+ self.check_result(dest_file_name, bert_expected)
+
+ def test_duplicate_module_with_copied_from(self):
+ bert_model_patterns = ModelPatterns("Bert", "bert-base-cased")
+ new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
+ bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert
+class TFBertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ load_tf_weights = load_tf_weights_in_bert
+ base_model_prefix = "bert"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+BERT_CONSTANT = "value"
+'''
+ bert_expected = '''# Copied from transformers.models.xxx.XxxModel with Xxx->NewBert
+class TFNewBertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = NewBertConfig
+ load_tf_weights = load_tf_weights_in_new_bert
+ base_model_prefix = "new_bert"
+ is_parallelizable = True
+ supports_gradient_checkpointing = True
+
+NEW_BERT_CONSTANT = "value"
+'''
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ work_dir = os.path.join(tmp_dir, "transformers")
+ os.makedirs(work_dir)
+ file_name = os.path.join(work_dir, "bert_module.py")
+ dest_file_name = os.path.join(work_dir, "new_bert_module.py")
+
+ self.init_file(file_name, bert_test)
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
+ # There should not be a new Copied from statement; the old one should be adapted.
+ self.check_result(dest_file_name, bert_expected)
+
+ self.init_file(file_name, bert_test)
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
+ self.check_result(dest_file_name, bert_expected)
+
+ def test_filter_framework_files(self):
+ files = ["modeling_tf_bert.py", "modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
+ self.assertEqual(filter_framework_files(files), files)
+ self.assertEqual(set(filter_framework_files(files, ["pt", "tf", "flax"])), set(files))
+
+ self.assertEqual(set(filter_framework_files(files, ["pt"])), {"modeling_bert.py", "configuration_bert.py"})
+ self.assertEqual(set(filter_framework_files(files, ["tf"])), {"modeling_tf_bert.py", "configuration_bert.py"})
+ self.assertEqual(
+ set(filter_framework_files(files, ["flax"])), {"modeling_flax_bert.py", "configuration_bert.py"}
+ )
+
+ self.assertEqual(
+ set(filter_framework_files(files, ["pt", "tf"])),
+ {"modeling_tf_bert.py", "modeling_bert.py", "configuration_bert.py"},
+ )
+ self.assertEqual(
+ set(filter_framework_files(files, ["tf", "flax"])),
+ {"modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"},
+ )
+ self.assertEqual(
+ set(filter_framework_files(files, ["pt", "flax"])),
+ {"modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"},
+ )
+
+ def test_get_model_files(self):
+ # BERT
+ bert_files = get_model_files("bert")
+
+ doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
+ self.assertEqual(model_files, BERT_MODEL_FILES)
+
+ self.assertEqual(bert_files["module_name"], "bert")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
+ bert_test_files = {
+ "tests/test_tokenization_bert.py",
+ "tests/test_modeling_bert.py",
+ "tests/test_modeling_tf_bert.py",
+ "tests/test_modeling_flax_bert.py",
+ }
+ self.assertEqual(test_files, bert_test_files)
+
+ # VIT
+ vit_files = get_model_files("vit")
+ doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
+ self.assertEqual(model_files, VIT_MODEL_FILES)
+
+ self.assertEqual(vit_files["module_name"], "vit")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
+ vit_test_files = {
+ "tests/test_feature_extraction_vit.py",
+ "tests/test_modeling_vit.py",
+ "tests/test_modeling_tf_vit.py",
+ "tests/test_modeling_flax_vit.py",
+ }
+ self.assertEqual(test_files, vit_test_files)
+
+ # Wav2Vec2
+ wav2vec2_files = get_model_files("wav2vec2")
+ doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
+ self.assertEqual(model_files, WAV2VEC2_MODEL_FILES)
+
+ self.assertEqual(wav2vec2_files["module_name"], "wav2vec2")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
+ wav2vec2_test_files = {
+ "tests/test_feature_extraction_wav2vec2.py",
+ "tests/test_modeling_wav2vec2.py",
+ "tests/test_modeling_tf_wav2vec2.py",
+ "tests/test_modeling_flax_wav2vec2.py",
+ "tests/test_processor_wav2vec2.py",
+ "tests/test_tokenization_wav2vec2.py",
+ }
+ self.assertEqual(test_files, wav2vec2_test_files)
+
+ def test_get_model_files_only_pt(self):
+ # BERT
+ bert_files = get_model_files("bert", frameworks=["pt"])
+
+ doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
+ bert_model_files = BERT_MODEL_FILES - {
+ "src/transformers/models/bert/modeling_tf_bert.py",
+ "src/transformers/models/bert/modeling_flax_bert.py",
+ }
+ self.assertEqual(model_files, bert_model_files)
+
+ self.assertEqual(bert_files["module_name"], "bert")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
+ bert_test_files = {
+ "tests/test_tokenization_bert.py",
+ "tests/test_modeling_bert.py",
+ }
+ self.assertEqual(test_files, bert_test_files)
+
+ # VIT
+ vit_files = get_model_files("vit", frameworks=["pt"])
+ doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
+ vit_model_files = VIT_MODEL_FILES - {
+ "src/transformers/models/vit/modeling_tf_vit.py",
+ "src/transformers/models/vit/modeling_flax_vit.py",
+ }
+ self.assertEqual(model_files, vit_model_files)
+
+ self.assertEqual(vit_files["module_name"], "vit")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
+ vit_test_files = {
+ "tests/test_feature_extraction_vit.py",
+ "tests/test_modeling_vit.py",
+ }
+ self.assertEqual(test_files, vit_test_files)
+
+ # Wav2Vec2
+ wav2vec2_files = get_model_files("wav2vec2", frameworks=["pt"])
+ doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
+ wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {
+ "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py",
+ "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py",
+ }
+ self.assertEqual(model_files, wav2vec2_model_files)
+
+ self.assertEqual(wav2vec2_files["module_name"], "wav2vec2")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
+ wav2vec2_test_files = {
+ "tests/test_feature_extraction_wav2vec2.py",
+ "tests/test_modeling_wav2vec2.py",
+ "tests/test_processor_wav2vec2.py",
+ "tests/test_tokenization_wav2vec2.py",
+ }
+ self.assertEqual(test_files, wav2vec2_test_files)
+
+ def test_get_model_files_tf_and_flax(self):
+ # BERT
+ bert_files = get_model_files("bert", frameworks=["tf", "flax"])
+
+ doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
+ bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"}
+ self.assertEqual(model_files, bert_model_files)
+
+ self.assertEqual(bert_files["module_name"], "bert")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
+ bert_test_files = {
+ "tests/test_tokenization_bert.py",
+ "tests/test_modeling_tf_bert.py",
+ "tests/test_modeling_flax_bert.py",
+ }
+ self.assertEqual(test_files, bert_test_files)
+
+ # VIT
+ vit_files = get_model_files("vit", frameworks=["tf", "flax"])
+ doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
+ vit_model_files = VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"}
+ self.assertEqual(model_files, vit_model_files)
+
+ self.assertEqual(vit_files["module_name"], "vit")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
+ vit_test_files = {
+ "tests/test_feature_extraction_vit.py",
+ "tests/test_modeling_tf_vit.py",
+ "tests/test_modeling_flax_vit.py",
+ }
+ self.assertEqual(test_files, vit_test_files)
+
+ # Wav2Vec2
+ wav2vec2_files = get_model_files("wav2vec2", frameworks=["tf", "flax"])
+ doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
+
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
+ wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"}
+ self.assertEqual(model_files, wav2vec2_model_files)
+
+ self.assertEqual(wav2vec2_files["module_name"], "wav2vec2")
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
+ wav2vec2_test_files = {
+ "tests/test_feature_extraction_wav2vec2.py",
+ "tests/test_modeling_tf_wav2vec2.py",
+ "tests/test_modeling_flax_wav2vec2.py",
+ "tests/test_processor_wav2vec2.py",
+ "tests/test_tokenization_wav2vec2.py",
+ }
+ self.assertEqual(test_files, wav2vec2_test_files)
+
+ def test_find_base_model_checkpoint(self):
+ self.assertEqual(find_base_model_checkpoint("bert"), "bert-base-uncased")
+ self.assertEqual(find_base_model_checkpoint("gpt2"), "gpt2")
+
+ def test_retrieve_model_classes(self):
+ gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2").items()}
+ expected_gpt_classes = {
+ "pt": {"GPT2ForTokenClassification", "GPT2Model", "GPT2LMHeadModel", "GPT2ForSequenceClassification"},
+ "tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"},
+ "flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"},
+ }
+ self.assertEqual(gpt_classes, expected_gpt_classes)
+
+ del expected_gpt_classes["flax"]
+ gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2", frameworks=["pt", "tf"]).items()}
+ self.assertEqual(gpt_classes, expected_gpt_classes)
+
+ del expected_gpt_classes["pt"]
+ gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2", frameworks=["tf"]).items()}
+ self.assertEqual(gpt_classes, expected_gpt_classes)
+
+ def test_retrieve_info_for_model_with_bert(self):
+ bert_info = retrieve_info_for_model("bert")
+ bert_classes = [
+ "BertForTokenClassification",
+ "BertForQuestionAnswering",
+ "BertForNextSentencePrediction",
+ "BertForSequenceClassification",
+ "BertForMaskedLM",
+ "BertForMultipleChoice",
+ "BertModel",
+ "BertForPreTraining",
+ "BertLMHeadModel",
+ ]
+ expected_model_classes = {
+ "pt": set(bert_classes),
+ "tf": {f"TF{m}" for m in bert_classes},
+ "flax": {f"Flax{m}" for m in bert_classes[:-1]},
+ }
+
+ self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf", "flax"})
+ model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()}
+ self.assertEqual(model_classes, expected_model_classes)
+
+ all_bert_files = bert_info["model_files"]
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]}
+ self.assertEqual(model_files, BERT_MODEL_FILES)
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
+ bert_test_files = {
+ "tests/test_tokenization_bert.py",
+ "tests/test_modeling_bert.py",
+ "tests/test_modeling_tf_bert.py",
+ "tests/test_modeling_flax_bert.py",
+ }
+ self.assertEqual(test_files, bert_test_files)
+
+ doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
+
+ self.assertEqual(all_bert_files["module_name"], "bert")
+
+ bert_model_patterns = bert_info["model_patterns"]
+ self.assertEqual(bert_model_patterns.model_name, "BERT")
+ self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased")
+ self.assertEqual(bert_model_patterns.model_type, "bert")
+ self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
+ self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
+ self.assertEqual(bert_model_patterns.model_upper_cased, "BERT")
+ self.assertEqual(bert_model_patterns.config_class, "BertConfig")
+ self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer")
+ self.assertIsNone(bert_model_patterns.feature_extractor_class)
+ self.assertIsNone(bert_model_patterns.processor_class)
+
+ def test_retrieve_info_for_model_pt_tf_with_bert(self):
+ bert_info = retrieve_info_for_model("bert", frameworks=["pt", "tf"])
+ bert_classes = [
+ "BertForTokenClassification",
+ "BertForQuestionAnswering",
+ "BertForNextSentencePrediction",
+ "BertForSequenceClassification",
+ "BertForMaskedLM",
+ "BertForMultipleChoice",
+ "BertModel",
+ "BertForPreTraining",
+ "BertLMHeadModel",
+ ]
+ expected_model_classes = {"pt": set(bert_classes), "tf": {f"TF{m}" for m in bert_classes}}
+
+ self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf"})
+ model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()}
+ self.assertEqual(model_classes, expected_model_classes)
+
+ all_bert_files = bert_info["model_files"]
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]}
+ bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_flax_bert.py"}
+ self.assertEqual(model_files, bert_model_files)
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
+ bert_test_files = {
+ "tests/test_tokenization_bert.py",
+ "tests/test_modeling_bert.py",
+ "tests/test_modeling_tf_bert.py",
+ }
+ self.assertEqual(test_files, bert_test_files)
+
+ doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/bert.mdx")
+
+ self.assertEqual(all_bert_files["module_name"], "bert")
+
+ bert_model_patterns = bert_info["model_patterns"]
+ self.assertEqual(bert_model_patterns.model_name, "BERT")
+ self.assertEqual(bert_model_patterns.checkpoint, "bert-base-uncased")
+ self.assertEqual(bert_model_patterns.model_type, "bert")
+ self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
+ self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
+ self.assertEqual(bert_model_patterns.model_upper_cased, "BERT")
+ self.assertEqual(bert_model_patterns.config_class, "BertConfig")
+ self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer")
+ self.assertIsNone(bert_model_patterns.feature_extractor_class)
+ self.assertIsNone(bert_model_patterns.processor_class)
+
+ def test_retrieve_info_for_model_with_vit(self):
+ vit_info = retrieve_info_for_model("vit")
+ vit_classes = ["ViTForImageClassification", "ViTModel"]
+ expected_model_classes = {
+ "pt": set(vit_classes),
+ "tf": {f"TF{m}" for m in vit_classes},
+ "flax": {f"Flax{m}" for m in vit_classes},
+ }
+
+ self.assertEqual(set(vit_info["frameworks"]), {"pt", "tf", "flax"})
+ model_classes = {k: set(v) for k, v in vit_info["model_classes"].items()}
+ self.assertEqual(model_classes, expected_model_classes)
+
+ all_vit_files = vit_info["model_files"]
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["model_files"]}
+ self.assertEqual(model_files, VIT_MODEL_FILES)
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["test_files"]}
+ vit_test_files = {
+ "tests/test_feature_extraction_vit.py",
+ "tests/test_modeling_vit.py",
+ "tests/test_modeling_tf_vit.py",
+ "tests/test_modeling_flax_vit.py",
+ }
+ self.assertEqual(test_files, vit_test_files)
+
+ doc_file = str(Path(all_vit_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/vit.mdx")
+
+ self.assertEqual(all_vit_files["module_name"], "vit")
+
+ vit_model_patterns = vit_info["model_patterns"]
+ self.assertEqual(vit_model_patterns.model_name, "ViT")
+ self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224")
+ self.assertEqual(vit_model_patterns.model_type, "vit")
+ self.assertEqual(vit_model_patterns.model_lower_cased, "vit")
+ self.assertEqual(vit_model_patterns.model_camel_cased, "ViT")
+ self.assertEqual(vit_model_patterns.model_upper_cased, "VIT")
+ self.assertEqual(vit_model_patterns.config_class, "ViTConfig")
+ self.assertEqual(vit_model_patterns.feature_extractor_class, "ViTFeatureExtractor")
+ self.assertIsNone(vit_model_patterns.tokenizer_class)
+ self.assertIsNone(vit_model_patterns.processor_class)
+
+ def test_retrieve_info_for_model_with_wav2vec2(self):
+ wav2vec2_info = retrieve_info_for_model("wav2vec2")
+ wav2vec2_classes = [
+ "Wav2Vec2Model",
+ "Wav2Vec2ForPreTraining",
+ "Wav2Vec2ForAudioFrameClassification",
+ "Wav2Vec2ForCTC",
+ "Wav2Vec2ForMaskedLM",
+ "Wav2Vec2ForSequenceClassification",
+ "Wav2Vec2ForXVector",
+ ]
+ expected_model_classes = {
+ "pt": set(wav2vec2_classes),
+ "tf": {f"TF{m}" for m in wav2vec2_classes[:1]},
+ "flax": {f"Flax{m}" for m in wav2vec2_classes[:2]},
+ }
+
+ self.assertEqual(set(wav2vec2_info["frameworks"]), {"pt", "tf", "flax"})
+ model_classes = {k: set(v) for k, v in wav2vec2_info["model_classes"].items()}
+ self.assertEqual(model_classes, expected_model_classes)
+
+ all_wav2vec2_files = wav2vec2_info["model_files"]
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["model_files"]}
+ self.assertEqual(model_files, WAV2VEC2_MODEL_FILES)
+
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["test_files"]}
+ wav2vec2_test_files = {
+ "tests/test_feature_extraction_wav2vec2.py",
+ "tests/test_modeling_wav2vec2.py",
+ "tests/test_modeling_tf_wav2vec2.py",
+ "tests/test_modeling_flax_wav2vec2.py",
+ "tests/test_processor_wav2vec2.py",
+ "tests/test_tokenization_wav2vec2.py",
+ }
+ self.assertEqual(test_files, wav2vec2_test_files)
+
+ doc_file = str(Path(all_wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
+ self.assertEqual(doc_file, "docs/source/model_doc/wav2vec2.mdx")
+
+ self.assertEqual(all_wav2vec2_files["module_name"], "wav2vec2")
+
+ wav2vec2_model_patterns = wav2vec2_info["model_patterns"]
+ self.assertEqual(wav2vec2_model_patterns.model_name, "Wav2Vec2")
+ self.assertEqual(wav2vec2_model_patterns.checkpoint, "facebook/wav2vec2-base-960h")
+ self.assertEqual(wav2vec2_model_patterns.model_type, "wav2vec2")
+ self.assertEqual(wav2vec2_model_patterns.model_lower_cased, "wav2vec2")
+ self.assertEqual(wav2vec2_model_patterns.model_camel_cased, "Wav2Vec2")
+ self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV_2_VEC_2")
+ self.assertEqual(wav2vec2_model_patterns.config_class, "Wav2Vec2Config")
+ self.assertEqual(wav2vec2_model_patterns.feature_extractor_class, "Wav2Vec2FeatureExtractor")
+ self.assertEqual(wav2vec2_model_patterns.processor_class, "Wav2Vec2Processor")
+ self.assertEqual(wav2vec2_model_patterns.tokenizer_class, "Wav2Vec2CTCTokenizer")
+
+ def test_clean_frameworks_in_init_with_gpt(self):
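+ # clean_frameworks_in_init should drop the tokenizer imports when keep_processing=False and keep only the requested frameworks.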
+ test_init = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
+
+_import_structure = {
+ "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
+ "tokenization_gpt2": ["GPT2Tokenizer"],
+}
+
+if is_tokenizers_available():
+ _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
+
+if is_torch_available():
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
+
+if is_tf_available():
+ _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
+
+if is_flax_available():
+ _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
+
+if TYPE_CHECKING:
+ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
+ from .tokenization_gpt2 import GPT2Tokenizer
+
+ if is_tokenizers_available():
+ from .tokenization_gpt2_fast import GPT2TokenizerFast
+
+ if is_torch_available():
+ from .modeling_gpt2 import GPT2Model
+
+ if is_tf_available():
+ from .modeling_tf_gpt2 import TFGPT2Model
+
+ if is_flax_available():
+ from .modeling_flax_gpt2 import FlaxGPT2Model
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ init_no_tokenizer = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+
+_import_structure = {
+ "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
+}
+
+if is_torch_available():
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
+
+if is_tf_available():
+ _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
+
+if is_flax_available():
+ _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
+
+if TYPE_CHECKING:
+ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
+
+ if is_torch_available():
+ from .modeling_gpt2 import GPT2Model
+
+ if is_tf_available():
+ from .modeling_tf_gpt2 import TFGPT2Model
+
+ if is_flax_available():
+ from .modeling_flax_gpt2 import FlaxGPT2Model
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ init_pt_only = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
+
+_import_structure = {
+ "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
+ "tokenization_gpt2": ["GPT2Tokenizer"],
+}
+
+if is_tokenizers_available():
+ _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
+
+if is_torch_available():
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
+
+if TYPE_CHECKING:
+ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
+ from .tokenization_gpt2 import GPT2Tokenizer
+
+ if is_tokenizers_available():
+ from .tokenization_gpt2_fast import GPT2TokenizerFast
+
+ if is_torch_available():
+ from .modeling_gpt2 import GPT2Model
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ init_pt_only_no_tokenizer = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_torch_available
+
+_import_structure = {
+ "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
+}
+
+if is_torch_available():
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
+
+if TYPE_CHECKING:
+ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
+
+ if is_torch_available():
+ from .modeling_gpt2 import GPT2Model
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file_name = os.path.join(tmp_dir, "__init__.py")
+
+ self.init_file(file_name, test_init)
+ clean_frameworks_in_init(file_name, keep_processing=False)
+ self.check_result(file_name, init_no_tokenizer)
+
+ self.init_file(file_name, test_init)
+ clean_frameworks_in_init(file_name, frameworks=["pt"])
+ self.check_result(file_name, init_pt_only)
+
+ self.init_file(file_name, test_init)
+ clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False)
+ self.check_result(file_name, init_pt_only_no_tokenizer)
+
+ def test_clean_frameworks_in_init_with_vit(self):
+ test_init = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
+
+_import_structure = {
+ "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
+}
+
+if is_vision_available():
+ _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
+
+if is_torch_available():
+ _import_structure["modeling_vit"] = ["ViTModel"]
+
+if is_tf_available():
+ _import_structure["modeling_tf_vit"] = ["TFViTModel"]
+
+if is_flax_available():
+ _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
+
+if TYPE_CHECKING:
+ from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
+
+ if is_vision_available():
+ from .feature_extraction_vit import ViTFeatureExtractor
+
+ if is_torch_available():
+ from .modeling_vit import ViTModel
+
+ if is_tf_available():
+ from .modeling_tf_vit import ViTModel
+
+ if is_flax_available():
+ from .modeling_flax_vit import ViTModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ init_no_feature_extractor = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
+
+_import_structure = {
+ "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
+}
+
+if is_torch_available():
+ _import_structure["modeling_vit"] = ["ViTModel"]
+
+if is_tf_available():
+ _import_structure["modeling_tf_vit"] = ["TFViTModel"]
+
+if is_flax_available():
+ _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
+
+if TYPE_CHECKING:
+ from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
+
+ if is_torch_available():
+ from .modeling_vit import ViTModel
+
+ if is_tf_available():
+ from .modeling_tf_vit import ViTModel
+
+ if is_flax_available():
+ from .modeling_flax_vit import ViTModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ init_pt_only = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_torch_available, is_vision_available
+
+_import_structure = {
+ "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
+}
+
+if is_vision_available():
+ _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
+
+if is_torch_available():
+ _import_structure["modeling_vit"] = ["ViTModel"]
+
+if TYPE_CHECKING:
+ from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
+
+ if is_vision_available():
+ from .feature_extraction_vit import ViTFeatureExtractor
+
+ if is_torch_available():
+ from .modeling_vit import ViTModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ init_pt_only_no_feature_extractor = """
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_torch_available
+
+_import_structure = {
+ "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
+}
+
+if is_torch_available():
+ _import_structure["modeling_vit"] = ["ViTModel"]
+
+if TYPE_CHECKING:
+ from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
+
+ if is_torch_available():
+ from .modeling_vit import ViTModel
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
+"""
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file_name = os.path.join(tmp_dir, "__init__.py")
+
+ self.init_file(file_name, test_init)
+ clean_frameworks_in_init(file_name, keep_processing=False)
+ self.check_result(file_name, init_no_feature_extractor)
+
+ self.init_file(file_name, test_init)
+ clean_frameworks_in_init(file_name, frameworks=["pt"])
+ self.check_result(file_name, init_pt_only)
+
+ self.init_file(file_name, test_init)
+ clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False)
+ self.check_result(file_name, init_pt_only_no_feature_extractor)
+
+ def test_duplicate_doc_file(self):
+ test_doc = """
+# GPT2
+
+## Overview
+
+Overview of the model.
+
+## GPT2Config
+
+[[autodoc]] GPT2Config
+
+## GPT2Tokenizer
+
+[[autodoc]] GPT2Tokenizer
+ - save_vocabulary
+
+## GPT2TokenizerFast
+
+[[autodoc]] GPT2TokenizerFast
+
+## GPT2 specific outputs
+
+[[autodoc]] models.gpt2.modeling_gpt2.GPT2DoubleHeadsModelOutput
+
+[[autodoc]] models.gpt2.modeling_tf_gpt2.TFGPT2DoubleHeadsModelOutput
+
+## GPT2Model
+
+[[autodoc]] GPT2Model
+ - forward
+
+## TFGPT2Model
+
+[[autodoc]] TFGPT2Model
+ - call
+
+## FlaxGPT2Model
+
+[[autodoc]] FlaxGPT2Model
+ - __call__
+
+"""
+ test_new_doc = """
+# GPT-New New
+
+## Overview
+
+The GPT-New New model was proposed in [() by .
+
+
+The abstract from the paper is the following:
+
+**
+
+Tips:
+
+
+
+This model was contributed by [INSERT YOUR HF USERNAME HERE]().
+The original code can be found [here]().
+
+
+## GPTNewNewConfig
+
+[[autodoc]] GPTNewNewConfig
+
+## GPTNewNewTokenizer
+
+[[autodoc]] GPTNewNewTokenizer
+ - save_vocabulary
+
+## GPTNewNewTokenizerFast
+
+[[autodoc]] GPTNewNewTokenizerFast
+
+## GPTNewNew specific outputs
+
+[[autodoc]] models.gpt_new_new.modeling_gpt_new_new.GPTNewNewDoubleHeadsModelOutput
+
+[[autodoc]] models.gpt_new_new.modeling_tf_gpt_new_new.TFGPTNewNewDoubleHeadsModelOutput
+
+## GPTNewNewModel
+
+[[autodoc]] GPTNewNewModel
+ - forward
+
+## TFGPTNewNewModel
+
+[[autodoc]] TFGPTNewNewModel
+ - call
+
+## FlaxGPTNewNewModel
+
+[[autodoc]] FlaxGPTNewNewModel
+ - __call__
+
+"""
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ doc_file = os.path.join(tmp_dir, "gpt2.mdx")
+ new_doc_file = os.path.join(tmp_dir, "gpt-new-new.mdx")
+
+ gpt2_model_patterns = ModelPatterns("GPT2", "gpt2", tokenizer_class="GPT2Tokenizer")
+ new_model_patterns = ModelPatterns(
+ "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPTNewNewTokenizer"
+ )
+
+ self.init_file(doc_file, test_doc)
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
+ self.check_result(new_doc_file, test_new_doc)
+
+ test_new_doc_pt_only = test_new_doc.replace(
+ """
+## TFGPTNewNewModel
+
+[[autodoc]] TFGPTNewNewModel
+ - call
+
+## FlaxGPTNewNewModel
+
+[[autodoc]] FlaxGPTNewNewModel
+ - __call__
+
+""",
+ "",
+ )
+ self.init_file(doc_file, test_doc)
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
+ self.check_result(new_doc_file, test_new_doc_pt_only)
+
+ test_new_doc_no_tok = test_new_doc.replace(
+ """
+## GPTNewNewTokenizer
+
+[[autodoc]] GPTNewNewTokenizer
+ - save_vocabulary
+
+## GPTNewNewTokenizerFast
+
+[[autodoc]] GPTNewNewTokenizerFast
+""",
+ "",
+ )
+ new_model_patterns = ModelPatterns(
+ "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPT2Tokenizer"
+ )
+ self.init_file(doc_file, test_doc)
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
+ self.check_result(new_doc_file, test_new_doc_no_tok)
+
+ test_new_doc_pt_only_no_tok = test_new_doc_no_tok.replace(
+ """
+## TFGPTNewNewModel
+
+[[autodoc]] TFGPTNewNewModel
+ - call
+
+## FlaxGPTNewNewModel
+
+[[autodoc]] FlaxGPTNewNewModel
+ - __call__
+
+""",
+ "",
+ )
+ self.init_file(doc_file, test_doc)
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
+ self.check_result(new_doc_file, test_new_doc_pt_only_no_tok)
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 358d638a327..1036981a628 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -243,6 +243,7 @@ def create_reverse_dependency_map():
# Any module file that has a test name which can't be inferred automatically from its name should go here. A better
# approach is to (re-)name the test file accordingly, and second best to add the correspondence map here.
SPECIAL_MODULE_TO_TEST_MAP = {
+ "commands/add_new_model_like.py": "test_add_new_model_like.py",
"configuration_utils.py": "test_configuration_common.py",
"convert_graph_to_onnx.py": "test_onnx.py",
"data/data_collator.py": "test_data_collator.py",