Update tiny model creation script (#22202)

* Update UNCONVERTIBLE_MODEL_ARCHITECTURES

* Deal with 2 model tester classes in single test file

* Deal with 2 model tester classes in single test file

* Deal with 2 model tester classes in single test file

* make style and quality

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2023-03-16 14:21:58 +01:00 committed by GitHub
parent 464d420775
commit 4c5c0af7e5

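The "Deal with 2 model tester classes in single test file" commits come down to one rule in the diff below: when a model's test file defines more than one `ModelTester`, collect every candidate and keep the alphabetically first one, so the choice is deterministic. A minimal, self-contained sketch of that rule (the tester class names are made up for illustration):

    class DummyModelTester:
        """Hypothetical first tester class defined in a test file."""

    class DummyOtherModelTester:
        """Hypothetical second tester class defined in the same test file."""

    # Gather every candidate tester, then keep the alphabetically first one so the
    # selection does not depend on discovery order.
    tester_classes = [DummyOtherModelTester, DummyModelTester]
    model_tester_class = None
    if len(tester_classes) > 0:
        model_tester_class = sorted(tester_classes, key=lambda x: x.__name__)[0]

    print(model_tester_class.__name__)  # -> "DummyModelTester"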

@@ -16,18 +16,17 @@
 import argparse
 import collections.abc
 import copy
-import importlib
 import inspect
 import json
 import os
 import shutil
-import sys
 import tempfile
 import traceback
 from pathlib import Path
 
 from check_config_docstrings import get_checkpoint_from_config_class
 from datasets import load_dataset
+from get_test_info import get_model_to_tester_mapping, get_tester_classes_for_model
 from huggingface_hub import Repository, create_repo, upload_folder
 
 from transformers import (
@@ -58,7 +57,6 @@ logging.set_verbosity_error()
 logging.disable_progress_bar()
 logger = logging.get_logger(__name__)
 
-sys.path.append(".")
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
 if not is_torch_available():
@@ -67,6 +65,7 @@ if not is_torch_available():
 if not is_tf_available():
     raise ValueError("Please install TensorFlow.")
 
 FRAMEWORKS = ["pytorch", "tensorflow"]
 INVALID_ARCH = []
 TARGET_VOCAB_SIZE = 1024
@@ -94,8 +93,12 @@ UNCONVERTIBLE_MODEL_ARCHITECTURES = {
     "TFCamembertModel",
     "TFCamembertForCausalLM",
     "DecisionTransformerModel",
+    "GraphormerModel",
+    "InformerModel",
     "JukeboxModel",
     "MarianForCausalLM",
+    "MaskFormerSwinModel",
+    "MaskFormerSwinBackbone",
     "MT5Model",
     "MT5ForConditionalGeneration",
     "TFMT5ForConditionalGeneration",
@@ -126,6 +129,7 @@ UNCONVERTIBLE_MODEL_ARCHITECTURES = {
     "XLMRobertaForQuestionAnswering",
     "TFXLMRobertaForSequenceClassification",
     "TFXLMRobertaForMaskedLM",
+    "TFXLMRobertaForCausalLM",
     "TFXLMRobertaForQuestionAnswering",
     "TFXLMRobertaModel",
     "TFXLMRobertaForMultipleChoice",
@@ -355,7 +359,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
     return processor
 
 
-def get_tiny_config(config_class, **model_tester_kwargs):
+def get_tiny_config(config_class, model_class=None, **model_tester_kwargs):
     """Retrieve a tiny configuration from `config_class` using each model's `ModelTester`.
 
     Args:
@@ -378,9 +382,18 @@ def get_tiny_config(config_class, **model_tester_kwargs):
         module_name = model_type_to_module_name(model_type)
         if not modeling_name.startswith(module_name):
             raise ValueError(f"{modeling_name} doesn't start with {module_name}!")
-        module = importlib.import_module(f".models.{module_name}.test_modeling_{modeling_name}", package="tests")
-        camel_case_model_name = config_class.__name__.split("Config")[0]
-        model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
+        test_file = os.path.join("tests", "models", module_name, f"test_modeling_{modeling_name}.py")
+        models_to_model_testers = get_model_to_tester_mapping(test_file)
+        # Find the model tester class
+        model_tester_class = None
+        tester_classes = []
+        if model_class is not None:
+            tester_classes = get_tester_classes_for_model(test_file, model_class)
+        else:
+            for _tester_classes in models_to_model_testers.values():
+                tester_classes.extend(_tester_classes)
+        if len(tester_classes) > 0:
+            model_tester_class = sorted(tester_classes, key=lambda x: x.__name__)[0]
     except ModuleNotFoundError:
         error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name."
         raise ValueError(error)
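With the new `model_class` argument, callers can restrict the search to the tester classes tied to one architecture. A hypothetical call site of the updated helper (BertConfig and BertForMaskedLM are only illustrative; the real callers live elsewhere in this script):

    from transformers import BertConfig, BertForMaskedLM

    # No model_class: every tester class found in the model's test file is a candidate,
    # and the alphabetically first one provides the tiny config.
    tiny_config = get_tiny_config(BertConfig)

    # With model_class: only the tester classes associated with that architecture,
    # as returned by get_tester_classes_for_model, are considered.
    tiny_config = get_tiny_config(BertConfig, model_class=BertForMaskedLM)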