mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Update tiny model creation script (#22202)
* Update UNCONVERTIBLE_MODEL_ARCHITECTURES * Deal with 2 model tester classes in single test file * Deal with 2 model tester classes in single test file * Deal with 2 model tester classes in single test file * make style and quality --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
464d420775
commit
4c5c0af7e5
@ -16,18 +16,17 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import copy
|
import copy
|
||||||
import importlib
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from check_config_docstrings import get_checkpoint_from_config_class
|
from check_config_docstrings import get_checkpoint_from_config_class
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from get_test_info import get_model_to_tester_mapping, get_tester_classes_for_model
|
||||||
from huggingface_hub import Repository, create_repo, upload_folder
|
from huggingface_hub import Repository, create_repo, upload_folder
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@ -58,7 +57,6 @@ logging.set_verbosity_error()
|
|||||||
logging.disable_progress_bar()
|
logging.disable_progress_bar()
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
sys.path.append(".")
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
|
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
@ -67,6 +65,7 @@ if not is_torch_available():
|
|||||||
if not is_tf_available():
|
if not is_tf_available():
|
||||||
raise ValueError("Please install TensorFlow.")
|
raise ValueError("Please install TensorFlow.")
|
||||||
|
|
||||||
|
|
||||||
FRAMEWORKS = ["pytorch", "tensorflow"]
|
FRAMEWORKS = ["pytorch", "tensorflow"]
|
||||||
INVALID_ARCH = []
|
INVALID_ARCH = []
|
||||||
TARGET_VOCAB_SIZE = 1024
|
TARGET_VOCAB_SIZE = 1024
|
||||||
@ -94,8 +93,12 @@ UNCONVERTIBLE_MODEL_ARCHITECTURES = {
|
|||||||
"TFCamembertModel",
|
"TFCamembertModel",
|
||||||
"TFCamembertForCausalLM",
|
"TFCamembertForCausalLM",
|
||||||
"DecisionTransformerModel",
|
"DecisionTransformerModel",
|
||||||
|
"GraphormerModel",
|
||||||
|
"InformerModel",
|
||||||
"JukeboxModel",
|
"JukeboxModel",
|
||||||
"MarianForCausalLM",
|
"MarianForCausalLM",
|
||||||
|
"MaskFormerSwinModel",
|
||||||
|
"MaskFormerSwinBackbone",
|
||||||
"MT5Model",
|
"MT5Model",
|
||||||
"MT5ForConditionalGeneration",
|
"MT5ForConditionalGeneration",
|
||||||
"TFMT5ForConditionalGeneration",
|
"TFMT5ForConditionalGeneration",
|
||||||
@ -126,6 +129,7 @@ UNCONVERTIBLE_MODEL_ARCHITECTURES = {
|
|||||||
"XLMRobertaForQuestionAnswering",
|
"XLMRobertaForQuestionAnswering",
|
||||||
"TFXLMRobertaForSequenceClassification",
|
"TFXLMRobertaForSequenceClassification",
|
||||||
"TFXLMRobertaForMaskedLM",
|
"TFXLMRobertaForMaskedLM",
|
||||||
|
"TFXLMRobertaForCausalLM",
|
||||||
"TFXLMRobertaForQuestionAnswering",
|
"TFXLMRobertaForQuestionAnswering",
|
||||||
"TFXLMRobertaModel",
|
"TFXLMRobertaModel",
|
||||||
"TFXLMRobertaForMultipleChoice",
|
"TFXLMRobertaForMultipleChoice",
|
||||||
@ -355,7 +359,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
|
|||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
def get_tiny_config(config_class, **model_tester_kwargs):
|
def get_tiny_config(config_class, model_class=None, **model_tester_kwargs):
|
||||||
"""Retrieve a tiny configuration from `config_class` using each model's `ModelTester`.
|
"""Retrieve a tiny configuration from `config_class` using each model's `ModelTester`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -378,9 +382,18 @@ def get_tiny_config(config_class, **model_tester_kwargs):
|
|||||||
module_name = model_type_to_module_name(model_type)
|
module_name = model_type_to_module_name(model_type)
|
||||||
if not modeling_name.startswith(module_name):
|
if not modeling_name.startswith(module_name):
|
||||||
raise ValueError(f"{modeling_name} doesn't start with {module_name}!")
|
raise ValueError(f"{modeling_name} doesn't start with {module_name}!")
|
||||||
module = importlib.import_module(f".models.{module_name}.test_modeling_{modeling_name}", package="tests")
|
test_file = os.path.join("tests", "models", module_name, f"test_modeling_{modeling_name}.py")
|
||||||
camel_case_model_name = config_class.__name__.split("Config")[0]
|
models_to_model_testers = get_model_to_tester_mapping(test_file)
|
||||||
model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
|
# Find the model tester class
|
||||||
|
model_tester_class = None
|
||||||
|
tester_classes = []
|
||||||
|
if model_class is not None:
|
||||||
|
tester_classes = get_tester_classes_for_model(test_file, model_class)
|
||||||
|
else:
|
||||||
|
for _tester_classes in models_to_model_testers.values():
|
||||||
|
tester_classes.extend(_tester_classes)
|
||||||
|
if len(tester_classes) > 0:
|
||||||
|
model_tester_class = sorted(tester_classes, key=lambda x: x.__name__)[0]
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name."
|
error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name."
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
Loading…
Reference in New Issue
Block a user