mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Update tiny model creation script (#22202)
* Update UNCONVERTIBLE_MODEL_ARCHITECTURES * Deal with 2 model tester classes in single test file * Deal with 2 model tester classes in single test file * Deal with 2 model tester classes in single test file * make style and quality --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
464d420775
commit
4c5c0af7e5
@ -16,18 +16,17 @@
|
||||
import argparse
|
||||
import collections.abc
|
||||
import copy
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from check_config_docstrings import get_checkpoint_from_config_class
|
||||
from datasets import load_dataset
|
||||
from get_test_info import get_model_to_tester_mapping, get_tester_classes_for_model
|
||||
from huggingface_hub import Repository, create_repo, upload_folder
|
||||
|
||||
from transformers import (
|
||||
@ -58,7 +57,6 @@ logging.set_verbosity_error()
|
||||
logging.disable_progress_bar()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
sys.path.append(".")
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
if not is_torch_available():
|
||||
@ -67,6 +65,7 @@ if not is_torch_available():
|
||||
if not is_tf_available():
|
||||
raise ValueError("Please install TensorFlow.")
|
||||
|
||||
|
||||
FRAMEWORKS = ["pytorch", "tensorflow"]
|
||||
INVALID_ARCH = []
|
||||
TARGET_VOCAB_SIZE = 1024
|
||||
@ -94,8 +93,12 @@ UNCONVERTIBLE_MODEL_ARCHITECTURES = {
|
||||
"TFCamembertModel",
|
||||
"TFCamembertForCausalLM",
|
||||
"DecisionTransformerModel",
|
||||
"GraphormerModel",
|
||||
"InformerModel",
|
||||
"JukeboxModel",
|
||||
"MarianForCausalLM",
|
||||
"MaskFormerSwinModel",
|
||||
"MaskFormerSwinBackbone",
|
||||
"MT5Model",
|
||||
"MT5ForConditionalGeneration",
|
||||
"TFMT5ForConditionalGeneration",
|
||||
@ -126,6 +129,7 @@ UNCONVERTIBLE_MODEL_ARCHITECTURES = {
|
||||
"XLMRobertaForQuestionAnswering",
|
||||
"TFXLMRobertaForSequenceClassification",
|
||||
"TFXLMRobertaForMaskedLM",
|
||||
"TFXLMRobertaForCausalLM",
|
||||
"TFXLMRobertaForQuestionAnswering",
|
||||
"TFXLMRobertaModel",
|
||||
"TFXLMRobertaForMultipleChoice",
|
||||
@ -355,7 +359,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
|
||||
return processor
|
||||
|
||||
|
||||
def get_tiny_config(config_class, **model_tester_kwargs):
|
||||
def get_tiny_config(config_class, model_class=None, **model_tester_kwargs):
|
||||
"""Retrieve a tiny configuration from `config_class` using each model's `ModelTester`.
|
||||
|
||||
Args:
|
||||
@ -378,9 +382,18 @@ def get_tiny_config(config_class, **model_tester_kwargs):
|
||||
module_name = model_type_to_module_name(model_type)
|
||||
if not modeling_name.startswith(module_name):
|
||||
raise ValueError(f"{modeling_name} doesn't start with {module_name}!")
|
||||
module = importlib.import_module(f".models.{module_name}.test_modeling_{modeling_name}", package="tests")
|
||||
camel_case_model_name = config_class.__name__.split("Config")[0]
|
||||
model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
|
||||
test_file = os.path.join("tests", "models", module_name, f"test_modeling_{modeling_name}.py")
|
||||
models_to_model_testers = get_model_to_tester_mapping(test_file)
|
||||
# Find the model tester class
|
||||
model_tester_class = None
|
||||
tester_classes = []
|
||||
if model_class is not None:
|
||||
tester_classes = get_tester_classes_for_model(test_file, model_class)
|
||||
else:
|
||||
for _tester_classes in models_to_model_testers.values():
|
||||
tester_classes.extend(_tester_classes)
|
||||
if len(tester_classes) > 0:
|
||||
model_tester_class = sorted(tester_classes, key=lambda x: x.__name__)[0]
|
||||
except ModuleNotFoundError:
|
||||
error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name."
|
||||
raise ValueError(error)
|
||||
|
Loading…
Reference in New Issue
Block a user