mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Benchmarks] improve Example Plotter (#5245)
* improve plotting * better labels * fix time plot
This commit is contained in:
parent
88d7f96e33
commit
79a82cc06a
@ -1,7 +1,7 @@
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@ -10,6 +10,10 @@ from matplotlib.ticker import ScalarFormatter
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
return field(default_factory=lambda: default, metadata=metadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotArguments:
|
||||
"""
|
||||
@ -37,6 +41,25 @@ class PlotArguments:
|
||||
figure_png_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
|
||||
)
|
||||
short_model_names: Optional[List[str]] = list_field(
|
||||
default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
|
||||
)
|
||||
|
||||
|
||||
def can_convert_to_int(string):
|
||||
try:
|
||||
int(string)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def can_convert_to_float(string):
|
||||
try:
|
||||
float(string)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class Plot:
|
||||
@ -50,9 +73,16 @@ class Plot:
|
||||
model_name = row["model"]
|
||||
self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
|
||||
self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
|
||||
self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
|
||||
"result"
|
||||
]
|
||||
if can_convert_to_int(row["result"]):
|
||||
# value is not None
|
||||
self.result_dict[model_name]["result"][
|
||||
(int(row["batch_size"]), int(row["sequence_length"]))
|
||||
] = int(row["result"])
|
||||
elif can_convert_to_float(row["result"]):
|
||||
# value is not None
|
||||
self.result_dict[model_name]["result"][
|
||||
(int(row["batch_size"]), int(row["sequence_length"]))
|
||||
] = float(row["result"])
|
||||
|
||||
def plot(self):
|
||||
fig, ax = plt.subplots()
|
||||
@ -67,7 +97,7 @@ class Plot:
|
||||
for axis in [ax.xaxis, ax.yaxis]:
|
||||
axis.set_major_formatter(ScalarFormatter())
|
||||
|
||||
for model_name in self.result_dict.keys():
|
||||
for model_name_idx, model_name in enumerate(self.result_dict.keys()):
|
||||
batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
|
||||
sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
|
||||
results = self.result_dict[model_name]["result"]
|
||||
@ -76,23 +106,33 @@ class Plot:
|
||||
(batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
|
||||
)
|
||||
|
||||
label_model_name = (
|
||||
model_name if self.args.short_model_names is None else self.args.short_model_names[model_name_idx]
|
||||
)
|
||||
|
||||
for inner_loop_value in inner_loop_array:
|
||||
if self.args.plot_along_batch:
|
||||
y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int)
|
||||
y_axis_array = np.asarray(
|
||||
[results[(x, inner_loop_value)] for x in x_axis_array if (x, inner_loop_value) in results],
|
||||
dtype=np.int,
|
||||
)
|
||||
else:
|
||||
y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32)
|
||||
y_axis_array = np.asarray(
|
||||
[results[(inner_loop_value, x)] for x in x_axis_array if (inner_loop_value, x) in results],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
(x_axis_label, inner_loop_label) = (
|
||||
("batch_size", "sequence_length in #tokens")
|
||||
if self.args.plot_along_batch
|
||||
else ("sequence_length in #tokens", "batch_size")
|
||||
("batch_size", "len") if self.args.plot_along_batch else ("in #tokens", "bsz")
|
||||
)
|
||||
|
||||
x_axis_array = np.asarray(x_axis_array, np.int)
|
||||
plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
|
||||
x_axis_array = np.asarray(x_axis_array, np.int)[: len(y_axis_array)]
|
||||
plt.scatter(
|
||||
x_axis_array, y_axis_array, label=f"{label_model_name} - {inner_loop_label}: {inner_loop_value}"
|
||||
)
|
||||
plt.plot(x_axis_array, y_axis_array, "--")
|
||||
|
||||
title_str += f" {model_name} vs."
|
||||
title_str += f" {label_model_name} vs."
|
||||
|
||||
title_str = title_str[:-4]
|
||||
y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
|
||||
|
4
examples/benchmarking/time_xla_1.csv
Normal file
4
examples/benchmarking/time_xla_1.csv
Normal file
@ -0,0 +1,4 @@
|
||||
model,batch_size,sequence_length,result
|
||||
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,8,512,0.2032
|
||||
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,64,512,1.5279
|
||||
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,256,512,6.1837
|
|
@ -74,12 +74,6 @@ class BenchmarkArguments:
|
||||
"help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
|
||||
},
|
||||
)
|
||||
with_lm_head: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use model with its language model head (MODEL_WITH_LM_HEAD_MAPPING instead of MODEL_MAPPING)"
|
||||
},
|
||||
)
|
||||
inference_time_csv_file: str = field(
|
||||
default=f"inference_time_{round(time())}.csv",
|
||||
metadata={"help": "CSV filename used if saving time results to csv."},
|
||||
|
Loading…
Reference in New Issue
Block a user