Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
[Benchmark] Extend Benchmark to all model type extensions (#5241)
* add benchmark for all kinds of models
* improved import
* delete bogus files
* make style

This commit is contained in:
parent 7c41057d50
commit 9fe09cec76
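As a rough usage sketch (not part of the diff below; the argument names mirror the tests added in this commit, and the tiny model ID is only illustrative), the new `only_pretrain_model` switch lets the benchmark load just the bare pretrained architecture instead of the class listed in `config.architectures`:

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# Benchmark inference of the bare pretrained model only, skipping the
# task-specific class that would otherwise be resolved from config.architectures.
benchmark_args = PyTorchBenchmarkArguments(
    models=["sshleifer/tiny-gpt2"],  # illustrative model ID
    training=False,
    no_inference=False,
    sequence_lengths=[8],
    batch_sizes=[1],
    no_multi_process=True,
    only_pretrain_model=True,
)
results = PyTorchBenchmark(benchmark_args).run()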
@@ -5,6 +5,7 @@ from typing import Optional
 
 import matplotlib.pyplot as plt
 import numpy as np
+from matplotlib.ticker import ScalarFormatter
 
 from transformers import HfArgumentParser
 
@@ -24,6 +25,9 @@ class PlotArguments:
         default=False,
         metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
     )
+    no_log_scale: bool = field(
+        default=False, metadata={"help": "Disable logarithmic scale when plotting"},
+    )
     is_train: bool = field(
         default=False,
         metadata={
@@ -55,6 +59,14 @@ class Plot:
         title_str = "Time usage" if self.args.is_time else "Memory usage"
         title_str = title_str + " for training" if self.args.is_train else title_str + " for inference"
 
+        if not self.args.no_log_scale:
+            # set logarithm scales
+            ax.set_xscale("log")
+            ax.set_yscale("log")
+
+        for axis in [ax.xaxis, ax.yaxis]:
+            axis.set_major_formatter(ScalarFormatter())
+
         for model_name in self.result_dict.keys():
             batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
             sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
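A self-contained sketch (not from this commit) of the axis handling above: on log-scaled matplotlib axes the default tick labels are powers of ten, and ScalarFormatter switches them back to plain numbers, which is why it is imported and applied to both axes unless log scaling is disabled:

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter

fig, ax = plt.subplots()
ax.plot([1, 8, 32, 128], [2, 15, 60, 250], marker="o")

# same pattern as the change above: log scales with plain-number tick labels
ax.set_xscale("log")
ax.set_yscale("log")
for axis in [ax.xaxis, ax.yaxis]:
    axis.set_major_formatter(ScalarFormatter())

plt.savefig("toy_plot.png")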
@@ -64,17 +76,12 @@ class Plot:
                 (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
             )
 
-            plt.xlim(min(x_axis_array), max(x_axis_array))
-
             for inner_loop_value in inner_loop_array:
                 if self.args.plot_along_batch:
                     y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int)
                 else:
                     y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32)
 
-                ax.set_xscale("log", basex=2)
-                ax.set_yscale("log", basey=10)
-
                 (x_axis_label, inner_loop_label) = (
                     ("batch_size", "sequence_length in #tokens")
                     if self.args.plot_along_batch
@@ -87,8 +87,18 @@ class PyTorchBenchmark(Benchmark):
 
         if self.args.torchscript:
             config.torchscript = True
-        if self.args.with_lm_head:
-            model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
+
+        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        if not self.args.only_pretrain_model and has_model_class_in_config:
+            try:
+                model_class = config.architectures[0]
+                transformers_module = __import__("transformers", fromlist=[model_class])
+                model_cls = getattr(transformers_module, model_class)
+                model = model_cls(config)
+            except ImportError:
+                raise ImportError(
+                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+                )
         else:
             model = MODEL_MAPPING[config.__class__](config)
 
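In isolation, the class-resolution pattern used above looks roughly like this (a sketch assuming transformers with PyTorch is installed; the BERT names are illustrative and not taken from the diff):

from transformers import BertConfig

config = BertConfig()
config.architectures = ["BertForMaskedLM"]  # illustrative value; normally set in the hub config

# resolve the concrete model class from its name, as in the hunk above
model_class = config.architectures[0]
transformers_module = __import__("transformers", fromlist=[model_class])
model_cls = getattr(transformers_module, model_class)
model = model_cls(config)  # randomly initialised BertForMaskedLM
print(type(model).__name__)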
@@ -127,7 +137,20 @@ class PyTorchBenchmark(Benchmark):
 
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]
-        model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
+
+        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        if not self.args.only_pretrain_model and has_model_class_in_config:
+            try:
+                model_class = config.architectures[0]
+                transformers_module = __import__("transformers", fromlist=[model_class])
+                model_cls = getattr(transformers_module, model_class)
+                model = model_cls(config)
+            except ImportError:
+                raise ImportError(
+                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+                )
+        else:
+            model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
 
         if self.args.torchscript:
             raise NotImplementedError("Training for torchscript is currently not implemented")
@@ -105,6 +105,12 @@ class BenchmarkArguments:
         metadata={"help": "Log filename used if print statements are saved in log."},
     )
     repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
+    only_pretrain_model: bool = field(
+        default=False,
+        metadata={
+            "help": "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain model weights."
+        },
+    )
 
     def to_json_string(self):
         """
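Because BenchmarkArguments is consumed by HfArgumentParser, a boolean field declared like this should surface as a `--only_pretrain_model` command-line flag; a minimal sketch with a toy dataclass (not the real BenchmarkArguments):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ToyArgs:  # hypothetical stand-in for BenchmarkArguments
    only_pretrain_model: bool = field(
        default=False, metadata={"help": "Only load the pretrained model weights."}
    )

parser = HfArgumentParser(ToyArgs)
(args,) = parser.parse_args_into_dataclasses(args=["--only_pretrain_model"])
print(args.only_pretrain_model)  # True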
@@ -24,13 +24,7 @@ import timeit
 from functools import wraps
 from typing import Callable, Optional
 
-from transformers import (
-    TF_MODEL_MAPPING,
-    TF_MODEL_WITH_LM_HEAD_MAPPING,
-    PretrainedConfig,
-    is_py3nvml_available,
-    is_tf_available,
-)
+from transformers import TF_MODEL_MAPPING, PretrainedConfig, is_py3nvml_available, is_tf_available
 
 from .benchmark_utils import (
     Benchmark,
@@ -125,8 +119,17 @@ class TensorflowBenchmark(Benchmark):
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
 
-        if self.args.with_lm_head:
-            model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
+        has_model_class_in_config = hasattr(config, "architecture") and len(config.architectures) > 1
+        if not self.args.only_pretrain_model and has_model_class_in_config:
+            try:
+                model_class = "TF" + config.architectures[0]  # prepend 'TF' for tensorflow model
+                transformers_module = __import__("transformers", fromlist=[model_class])
+                model_cls = getattr(transformers_module, model_class)
+                model = model_cls(config)
+            except ImportError:
+                raise ImportError(
+                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to set `--only_pretrain_model` or `args.only_pretrain_model=True`."
+                )
         else:
             model = TF_MODEL_MAPPING[config.__class__](config)
 
@@ -752,6 +752,7 @@ class Benchmark(ABC):
         info["time"] = datetime.time(datetime.now())
         info["fp16"] = self.args.fp16
         info["use_multiprocessing"] = self.args.do_multi_processing
+        info["only_pretrain_model"] = self.args.only_pretrain_model
 
         if is_psutil_available():
             info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
@@ -38,6 +38,22 @@ class BenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)
 
+    def test_inference_no_configs_only_pretrain(self):
+        MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
+        benchmark_args = PyTorchBenchmarkArguments(
+            models=[MODEL_ID],
+            training=False,
+            no_inference=False,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+            no_multi_process=True,
+            only_pretrain_model=True,
+        )
+        benchmark = PyTorchBenchmark(benchmark_args)
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_inference_result)
+        self.check_results_dict_not_empty(results.memory_inference_result)
+
     def test_inference_torchscript(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
         benchmark_args = PyTorchBenchmarkArguments(
@@ -37,6 +37,22 @@ class TFBenchmarkTest(unittest.TestCase):
         self.check_results_dict_not_empty(results.time_inference_result)
         self.check_results_dict_not_empty(results.memory_inference_result)
 
+    def test_inference_no_configs_only_pretrain(self):
+        MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
+        benchmark_args = TensorflowBenchmarkArguments(
+            models=[MODEL_ID],
+            training=False,
+            no_inference=False,
+            sequence_lengths=[8],
+            batch_sizes=[1],
+            no_multi_process=True,
+            only_pretrain_model=True,
+        )
+        benchmark = TensorflowBenchmark(benchmark_args)
+        results = benchmark.run()
+        self.check_results_dict_not_empty(results.time_inference_result)
+        self.check_results_dict_not_empty(results.memory_inference_result)
+
     def test_inference_no_configs_graph(self):
         MODEL_ID = "sshleifer/tiny-gpt2"
         benchmark_args = TensorflowBenchmarkArguments(