Update tiny model creation script (#22058)

Update the script

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2023-03-09 19:53:54 +01:00 committed by GitHub
parent 7a2b915e92
commit 6d9031f285

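In short: the local `unexportable_model_architectures` list is promoted to a module-level `UNCONVERTIBLE_MODEL_ARCHITECTURES` constant, errors and warnings are recorded as `(message, trace)` pairs built with `traceback.format_exc()` instead of message-only strings, and a new `build_tiny_model_summary` step writes `tiny_model_summary.json`. A minimal sketch of the new error-reporting pattern (the `build_step` helper below is an illustrative stand-in, not code from the script):

    import traceback

    result = {"error": None, "warnings": []}

    def build_step():
        # Hypothetical stand-in for helpers such as build_processor or get_tiny_config.
        raise RuntimeError("checkpoint is missing a tokenizer file")

    try:
        build_step()
    except Exception:
        # Keep a short human-readable message alongside the full stack trace.
        result["error"] = ("Failed to run the build step.", traceback.format_exc())

    if result["error"] is not None:
        message, trace = result["error"]
        print(message)  # log output stays short: only the message
        print(trace)    # the JSON reports keep the full stack for debugging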

@@ -23,6 +23,7 @@ import os
import shutil
import sys
import tempfile
+import traceback
from pathlib import Path
from check_config_docstrings import get_checkpoint_from_config_class
@@ -71,6 +72,67 @@ INVALID_ARCH = []
TARGET_VOCAB_SIZE = 1024
+# This list contains the model architectures for which a tiny version could not be created.
+# Avoid adding new architectures here - unless we have verified carefully that it's (almost) impossible to create them.
+# One such case is: no model tester class is implemented for a model type (like `MT5`) because its architecture is
+# identical to another one (`MT5` is based on `T5`), but trained on different datasets or with different techniques.
+UNCONVERTIBLE_MODEL_ARCHITECTURES = {
+"BertGenerationEncoder",
+"BertGenerationDecoder",
+"CamembertForSequenceClassification",
+"CamembertForMultipleChoice",
+"CamembertForMaskedLM",
+"CamembertForCausalLM",
+"CamembertForTokenClassification",
+"CamembertForQuestionAnswering",
+"CamembertModel",
+"TFCamembertForMultipleChoice",
+"TFCamembertForTokenClassification",
+"TFCamembertForQuestionAnswering",
+"TFCamembertForSequenceClassification",
+"TFCamembertForMaskedLM",
+"TFCamembertModel",
+"TFCamembertForCausalLM",
+"DecisionTransformerModel",
+"JukeboxModel",
+"MarianForCausalLM",
+"MT5Model",
+"MT5ForConditionalGeneration",
+"TFMT5ForConditionalGeneration",
+"TFMT5Model",
+"QDQBertForSequenceClassification",
+"QDQBertForMaskedLM",
+"QDQBertModel",
+"QDQBertForTokenClassification",
+"QDQBertLMHeadModel",
+"QDQBertForMultipleChoice",
+"QDQBertForQuestionAnswering",
+"QDQBertForNextSentencePrediction",
+"ReformerModelWithLMHead",
+"RetriBertModel",
+"Speech2Text2ForCausalLM",
+"TimeSeriesTransformerModel",
+"TrajectoryTransformerModel",
+"TrOCRForCausalLM",
+"XLMProphetNetForConditionalGeneration",
+"XLMProphetNetForCausalLM",
+"XLMProphetNetModel",
+"XLMRobertaModel",
+"XLMRobertaForTokenClassification",
+"XLMRobertaForMultipleChoice",
+"XLMRobertaForMaskedLM",
+"XLMRobertaForCausalLM",
+"XLMRobertaForSequenceClassification",
+"XLMRobertaForQuestionAnswering",
+"TFXLMRobertaForSequenceClassification",
+"TFXLMRobertaForMaskedLM",
+"TFXLMRobertaForQuestionAnswering",
+"TFXLMRobertaModel",
+"TFXLMRobertaForMultipleChoice",
+"TFXLMRobertaForTokenClassification",
+}
def get_processor_types_from_config_class(config_class, allowed_mappings=None):
"""Return a tuple of processors for `config_class`.
@@ -131,7 +193,7 @@ def get_architectures_from_config_class(config_class, arch_mappings):
models = mapping[config_class]
models = tuple(models) if isinstance(models, collections.abc.Sequence) else (models,)
for model in models:
-if model.__name__ not in unexportable_model_architectures:
+if model.__name__ not in UNCONVERTIBLE_MODEL_ARCHITECTURES:
architectures.add(model)
architectures = tuple(architectures)
@@ -186,8 +248,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
try:
processor = processor_class.from_pretrained(checkpoint)
except Exception as e:
-logger.error(e)
-pass
+logger.error(f"{e.__class__.__name__}: {e}")
# Try to get a new processor class from checkpoint. This is helpful for a checkpoint without the necessary file to load
# a processor while `processor_class` is an Auto class. For example, `sew` has `Wav2Vec2Processor` in
@@ -203,7 +264,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
try:
config = AutoConfig.from_pretrained(checkpoint)
except Exception as e:
-logger.error(e)
+logger.error(f"{e.__class__.__name__}: {e}")
config = None
if config is not None:
if not isinstance(config, config_class):
@@ -263,8 +324,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
try:
processor = processor_class(**{k: v[0] for k, v in attrs.items()})
except Exception as e:
-logger.error(e)
-pass
+logger.error(f"{e.__class__.__name__}: {e}")
else:
# `checkpoint` might lack some file(s) to load a processor. For example, `facebook/hubert-base-ls960`
# has no tokenizer file to load `Wav2Vec2CTCTokenizer`. In this case, we try to build a processor
@@ -282,8 +342,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
try:
processor = processor_class()
except Exception as e:
-logger.error(e)
-pass
+logger.error(f"{e.__class__.__name__}: {e}")
# validation
if processor is not None:
@@ -322,12 +381,12 @@ def get_tiny_config(config_class, **model_tester_kwargs):
module = importlib.import_module(f".models.{module_name}.test_modeling_{modeling_name}", package="tests")
camel_case_model_name = config_class.__name__.split("Config")[0]
model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
-except ModuleNotFoundError as e:
-error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name"
-raise ValueError(f"{error}: {e}")
+except ModuleNotFoundError:
+error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name."
+raise ValueError(error)
if model_tester_class is None:
-error = f"Tiny config not created for {model_type} - no model tester is found in the testing module"
+error = f"Tiny config not created for {model_type} - no model tester is found in the testing module."
raise ValueError(error)
# `parent` is an instance of `unittest.TestCase`, but we don't need it here.
@@ -434,9 +493,12 @@ def convert_processors(processors, tiny_config, output_folder, result):
# be retrained
if fast_tokenizer.vocab_size > TARGET_VOCAB_SIZE:
fast_tokenizer = convert_tokenizer(tokenizer)
-except Exception as e:
+except Exception:
result["warnings"].append(
-f"Failed to convert the fast tokenizer for {fast_tokenizer.__class__.__name__}: {e}"
+(
+f"Failed to convert the fast tokenizer for {fast_tokenizer.__class__.__name__}.",
+traceback.format_exc(),
+)
)
continue
elif slow_tokenizer is None:
@@ -446,9 +508,12 @@ def convert_processors(processors, tiny_config, output_folder, result):
if fast_tokenizer:
try:
fast_tokenizer.save_pretrained(output_folder)
-except Exception as e:
+except Exception:
result["warnings"].append(
-f"Failed to save the fast tokenizer for {fast_tokenizer.__class__.__name__}: {e}"
+(
+f"Failed to save the fast tokenizer for {fast_tokenizer.__class__.__name__}.",
+traceback.format_exc(),
+)
)
fast_tokenizer = None
@@ -456,9 +521,12 @@ def convert_processors(processors, tiny_config, output_folder, result):
if fast_tokenizer:
try:
slow_tokenizer = AutoTokenizer.from_pretrained(output_folder, use_fast=False)
-except Exception as e:
+except Exception:
result["warnings"].append(
-f"Failed to load the slow tokenizer saved from {fast_tokenizer.__class__.__name__}: {e}"
+(
+f"Failed to load the slow tokenizer saved from {fast_tokenizer.__class__.__name__}.",
+traceback.format_exc(),
+)
)
# Let's just keep the fast version
slow_tokenizer = None
@@ -467,17 +535,25 @@ def convert_processors(processors, tiny_config, output_folder, result):
if not fast_tokenizer and slow_tokenizer:
try:
slow_tokenizer.save_pretrained(output_folder)
-except Exception as e:
+except Exception:
result["warnings"].append(
-f"Failed to save the slow tokenizer for {slow_tokenizer.__class__.__name__}: {e}"
+(
+f"Failed to save the slow tokenizer for {slow_tokenizer.__class__.__name__}.",
+traceback.format_exc(),
+)
)
slow_tokenizer = None
# update feature extractors using the tiny config
try:
feature_extractors = [convert_feature_extractor(p, tiny_config) for p in feature_extractors]
-except Exception as e:
-result["warnings"].append(f"Failed to convert feature extractors: {e}")
+except Exception:
+result["warnings"].append(
+(
+"Failed to convert feature extractors.",
+traceback.format_exc(),
+)
+)
feature_extractors = []
if hasattr(tiny_config, "max_position_embeddings") and tiny_config.max_position_embeddings > 0:
@@ -538,9 +614,9 @@ def build_model(model_arch, tiny_config, output_dir):
return model
-def fill_result_with_error(result, error, models_to_create):
+def fill_result_with_error(result, error, trace, models_to_create):
"""Fill `result` with errors for all target model arch if we can't build processor"""
+error = (error, trace)
result["error"] = error
for framework in FRAMEWORKS:
if framework in models_to_create:
@@ -548,7 +624,7 @@ def fill_result_with_error(result, error, models_to_create):
for model_arch in models_to_create[framework]:
result[framework][model_arch.__name__] = {"model": None, "checkpoint": None, "error": error}
-result["processor"] = {type(p).__name__: p.__class__.__name__ for p in result["processor"]}
+result["processor"] = {p.__class__.__name__: p.__class__.__name__ for p in result["processor"].values()}
def upload_model(model_dir, organization):
@@ -572,7 +648,7 @@ def upload_model(model_dir, organization):
except Exception as e:
error = e
if error is not None:
-raise ValueError(error)
+raise error
with tempfile.TemporaryDirectory() as tmpdir:
repo = Repository(local_dir=tmpdir, clone_from=f"{organization}/{repo_name}")
@@ -589,13 +665,13 @@ def upload_model(model_dir, organization):
commit_description=f"Upload tiny models for {arch_name}",
create_pr=True,
)
-logger.warning(f"PR open in {hub_pr_url}")
+logger.warning(f"PR open in {hub_pr_url}.")
else:
# Push to Hub repo directly
repo.git_add(auto_lfs_track=True)
repo.git_commit(f"Upload tiny models for {arch_name}")
repo.git_push(blocking=True) # this prints a progress bar with the upload
-logger.warning(f"Tiny models {arch_name} pushed to {organization}/{repo_name}")
+logger.warning(f"Tiny models {arch_name} pushed to {organization}/{repo_name}.")
def build_composite_models(config_class, output_dir):
@@ -715,7 +791,7 @@ def build_composite_models(config_class, output_dir):
shutil.copytree(decoder_processor_path, model_path, dirs_exist_ok=True)
# fill `result`
-result["processor"] = tuple({x.__name__ for x in encoder_processor + decoder_processor})
+result["processor"] = {x.__name__: x.__name__ for x in encoder_processor + decoder_processor}
result["pytorch"] = {model_class.__name__: {"model": model_class.__name__, "checkpoint": model_path}}
@@ -724,9 +800,11 @@ def build_composite_models(config_class, output_dir):
result["tensorflow"] = {
tf_model_class.__name__: {"model": tf_model_class.__name__, "checkpoint": model_path}
}
-except Exception as e:
-result["error"] = f"Failed to build models for {config_class.__name__}: {e}"
+except Exception:
+result["error"] = (
+f"Failed to build models for {config_class.__name__}.",
+traceback.format_exc(),
+)
if not result["error"]:
del result["error"]
@@ -862,8 +940,8 @@ def build(config_class, models_to_create, output_dir):
if len(processor_classes) == 0:
error = f"No processor class could be found in {config_class.__name__}."
-fill_result_with_error(result, error, models_to_create)
-logger.error(result["error"])
+fill_result_with_error(result, error, None, models_to_create)
+logger.error(result["error"][0])
return result
for processor_class in processor_classes:
@@ -871,24 +949,26 @@ def build(config_class, models_to_create, output_dir):
processor = build_processor(config_class, processor_class, allow_no_checkpoint=True)
if processor is not None:
result["processor"][processor_class] = processor
-except Exception as e:
-error = f"Failed to build processor for {processor_class.__name__}: {e}"
-fill_result_with_error(result, error, models_to_create)
-logger.error(result["error"])
+except Exception:
+error = f"Failed to build processor for {processor_class.__name__}."
+trace = traceback.format_exc()
+fill_result_with_error(result, error, trace, models_to_create)
+logger.error(result["error"][0])
return result
if len(result["processor"]) == 0:
error = f"No processor could be built for {config_class.__name__}."
-fill_result_with_error(result, error, models_to_create)
-logger.error(result["error"])
+fill_result_with_error(result, error, None, models_to_create)
+logger.error(result["error"][0])
return result
try:
tiny_config = get_tiny_config(config_class)
except Exception as e:
error = f"Failed to get tiny config for {config_class.__name__}: {e}"
-fill_result_with_error(result, error, models_to_create)
-logger.error(result["error"])
+trace = traceback.format_exc()
+fill_result_with_error(result, error, trace, models_to_create)
+logger.error(result["error"][0])
return result
# Convert the processors (reduce vocabulary size, smaller image size, etc.)
@@ -896,22 +976,24 @@ def build(config_class, models_to_create, output_dir):
processor_output_folder = os.path.join(output_dir, "processors")
try:
processors = convert_processors(processors, tiny_config, processor_output_folder, result)
-except Exception as e:
-error = f"Failed to convert the processors: {e}"
-result["warnings"].append(error)
+except Exception:
+error = "Failed to convert the processors."
+trace = traceback.format_exc()
+result["warnings"].append((error, trace))
if len(processors) == 0:
error = f"No processor is returned by `convert_processors` for {config_class.__name__}."
-fill_result_with_error(result, error, models_to_create)
-logger.error(result["error"])
+fill_result_with_error(result, error, None, models_to_create)
+logger.error(result["error"][0])
return result
try:
config_overrides = get_config_overrides(config_class, processors)
except Exception as e:
error = f"Failure occurs while calling `get_config_overrides`: {e}"
-fill_result_with_error(result, error, models_to_create)
-logger.error(result["error"])
+trace = traceback.format_exc()
+fill_result_with_error(result, error, trace, models_to_create)
+logger.error(result["error"][0])
return result
# Just for us to see this easily in the report
@@ -935,7 +1017,7 @@ def build(config_class, models_to_create, output_dir):
tiny_config.text_config_dict[k] = v
if result["warnings"]:
-logger.warning(result["warnings"])
+logger.warning(result["warnings"][0][0])
# update `result["processor"]`
result["processor"] = {type(p).__name__: p.__class__.__name__ for p in processors}
@@ -948,13 +1030,14 @@ def build(config_class, models_to_create, output_dir):
except Exception as e:
model = None
error = f"Failed to create the pytorch model for {pytorch_arch}: {e}"
+trace = traceback.format_exc()
result["pytorch"][pytorch_arch.__name__]["model"] = model.__class__.__name__ if model is not None else None
result["pytorch"][pytorch_arch.__name__]["checkpoint"] = (
get_checkpoint_dir(output_dir, pytorch_arch) if model is not None else None
)
if error is not None:
-result["pytorch"][pytorch_arch.__name__]["error"] = error
+result["pytorch"][pytorch_arch.__name__]["error"] = (error, trace)
logger.error(f"{pytorch_arch.__name__}: {error}")
for tensorflow_arch in models_to_create["tensorflow"]:
@@ -974,12 +1057,14 @@ def build(config_class, models_to_create, output_dir):
# Conversion may fail. Let's not create a model with different weights to avoid confusion (for now).
model = None
error = f"Failed to convert the pytorch model to the tensorflow model for {pt_arch}: {e}"
+trace = traceback.format_exc()
else:
try:
model = build_model(tensorflow_arch, tiny_config, output_dir=output_dir)
except Exception as e:
model = None
error = f"Failed to create the tensorflow model for {tensorflow_arch}: {e}"
+trace = traceback.format_exc()
result["tensorflow"][tensorflow_arch.__name__]["model"] = (
model.__class__.__name__ if model is not None else None
@@ -988,7 +1073,7 @@ def build(config_class, models_to_create, output_dir):
get_checkpoint_dir(output_dir, tensorflow_arch) if model is not None else None
)
if error is not None:
-result["tensorflow"][tensorflow_arch.__name__]["error"] = error
+result["tensorflow"][tensorflow_arch.__name__]["error"] = (error, trace)
logger.error(f"{tensorflow_arch.__name__}: {error}")
if not result["error"]:
@@ -999,6 +1084,37 @@ def build(config_class, models_to_create, output_dir):
return result
+def build_tiny_model_summary(results):
+"""Build a summary: a dictionary of the form
+{
+model architecture name:
+{
+"tokenizer_classes": [...],
+"processor_classes": [...]
+}
+...
+}
+"""
+tiny_model_summary = {}
+for config_name in results:
+processors = [key for key, value in results[config_name]["processor"].items()]
+tokenizer_classes = [x for x in processors if x.endswith("TokenizerFast") or x.endswith("Tokenizer")]
+processor_classes = [x for x in processors if x not in tokenizer_classes]
+for framework in FRAMEWORKS:
+if framework not in results[config_name]:
+continue
+for arch_name in results[config_name][framework]:
+# tiny model is not created for `arch_name`
+if results[config_name][framework][arch_name] is None:
+continue
+tiny_model_summary[arch_name] = {
+"tokenizer_classes": tokenizer_classes,
+"processor_classes": processor_classes,
+}
+return tiny_model_summary
def build_failed_report(results, include_warning=True):
failed_results = {}
for config_name in results:
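For reference, a `tiny_model_summary.json` entry produced by `build_tiny_model_summary` above would look like the following, written as the equivalent Python dict (the architecture and class names are illustrative, not taken from a real run):

    tiny_model_summary = {
        "BertForMaskedLM": {  # hypothetical architecture name
            "tokenizer_classes": ["BertTokenizer", "BertTokenizerFast"],
            "processor_classes": [],
        },
    }

Per the comments near the end of this diff, entries whose `tokenizer_classes` and `processor_classes` are both empty should not be copied into `tests/utils/tiny_model_summary.json`.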
@@ -1039,10 +1155,10 @@ def build_simple_report(results):
for arch_name in results[config_name][framework]:
if "error" in results[config_name][framework][arch_name]:
result = results[config_name][framework][arch_name]["error"]
-failed_text += f"{arch_name}: {result}\n"
+failed_text += f"{arch_name}: {result[0]}\n"
else:
-result = "OK"
-text += f"{arch_name}: {result}\n"
+result = ("OK",)
+text += f"{arch_name}: {result[0]}\n"
return text, failed_text
@@ -1066,8 +1182,6 @@ if __name__ == "__main__":
tensorflow_arch_mappings = [getattr(transformers_module, x) for x in _tensorflow_arch_mappings]
# flax_arch_mappings = [getattr(transformers_module, x) for x in _flax_arch_mappings]
-unexportable_model_architectures = []
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
training_ds = ds["train"]
testing_ds = ds["test"]
@@ -1129,16 +1243,27 @@ if __name__ == "__main__":
with open("tiny_model_creation_report.json", "w") as fp:
json.dump(results, fp, indent=4)
-# Build the failure report
+# Build the tiny model summary file. The `tokenizer_classes` and `processor_classes` could both be empty lists.
+# When using the items in this file to update the file `tests/utils/tiny_model_summary.json`, the model
+# architectures with `tokenizer_classes` and `processor_classes` being both empty should **NOT** be added to
+# `tests/utils/tiny_model_summary.json`.
+tiny_model_summary = build_tiny_model_summary(results)
+with open("tiny_model_summary.json", "w") as fp:
+json.dump(tiny_model_summary, fp, indent=4)
+# Build the warning/failure report (json format): same format as the complete `results` except this contains only
+# warnings or errors.
failed_results = build_failed_report(results)
with open("failed_report.json", "w") as fp:
json.dump(failed_results, fp, indent=4)
# Build the failure report
simple_report, failed_report = build_simple_report(results)
+# The simplified report: a .txt file with each line of the format:
+# {model architecture name}: {OK or error message}
with open("simple_report.txt", "w") as fp:
fp.write(simple_report)
+# The simplified failure report: same as above, except it only contains lines with errors
with open("simple_failed_report.txt", "w") as fp:
fp.write(failed_report)
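With errors stored as tuples, each line of `simple_report.txt` carries only the short message (element `[0]`), for example (hypothetical architectures and messages):

    BertForMaskedLM: OK
    FooBarModel: Failed to get tiny config for FooBarConfig.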
@@ -1160,7 +1285,7 @@ if __name__ == "__main__":
try:
upload_model(model_dir, args.organization)
except Exception as e:
-error = f"Failed to upload {model_dir}: {e}"
+error = f"Failed to upload {model_dir}. {e.__class__.__name__}: {e}"
logger.error(error)
upload_results[model_dir] = error
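Because errors now serialize as `(message, trace)` pairs (JSON arrays of two strings), a consumer of `failed_report.json` can keep its console output short and expand the stack only when needed. A sketch, assuming the report keeps the same per-config layout as `results`:

    import json

    with open("failed_report.json") as fp:
        failed = json.load(fp)

    for config_name, entry in failed.items():
        error = entry.get("error")
        if error is not None:
            message, trace = error  # a [message, trace] pair after this change
            print(f"{config_name}: {message}")
            # print(trace)  # uncomment for the full stack from traceback.format_exc()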