[Benchmarks] improve Example Plotter (#5245)

* improve plotting

* better labels

* fix time plot
Patrick von Platen 2020-06-26 15:00:14 +02:00 committed by GitHub
parent 88d7f96e33
commit 79a82cc06a
3 changed files with 57 additions and 19 deletions


@@ -1,7 +1,7 @@
 import csv
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional

 import matplotlib.pyplot as plt
 import numpy as np
@@ -10,6 +10,10 @@ from matplotlib.ticker import ScalarFormatter
 from transformers import HfArgumentParser


+def list_field(default=None, metadata=None):
+    return field(default_factory=lambda: default, metadata=metadata)
+
+
 @dataclass
 class PlotArguments:
     """
@@ -37,6 +41,25 @@ class PlotArguments:
     figure_png_file: Optional[str] = field(
         default=None, metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
     )
+    short_model_names: Optional[List[str]] = list_field(
+        default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
+    )
+
+
+def can_convert_to_int(string):
+    try:
+        int(string)
+        return True
+    except ValueError:
+        return False
+
+
+def can_convert_to_float(string):
+    try:
+        float(string)
+        return True
+    except ValueError:
+        return False


 class Plot:
@@ -50,9 +73,16 @@ class Plot:
                 model_name = row["model"]
                 self.result_dict[model_name]["bsz"].append(int(row["batch_size"]))
                 self.result_dict[model_name]["seq_len"].append(int(row["sequence_length"]))
-                self.result_dict[model_name]["result"][(int(row["batch_size"]), int(row["sequence_length"]))] = row[
-                    "result"
-                ]
+                if can_convert_to_int(row["result"]):
+                    # value is not None
+                    self.result_dict[model_name]["result"][
+                        (int(row["batch_size"]), int(row["sequence_length"]))
+                    ] = int(row["result"])
+                elif can_convert_to_float(row["result"]):
+                    # value is not None
+                    self.result_dict[model_name]["result"][
+                        (int(row["batch_size"]), int(row["sequence_length"]))
+                    ] = float(row["result"])

     def plot(self):
         fig, ax = plt.subplots()
@@ -67,7 +97,7 @@
         for axis in [ax.xaxis, ax.yaxis]:
             axis.set_major_formatter(ScalarFormatter())

-        for model_name in self.result_dict.keys():
+        for model_name_idx, model_name in enumerate(self.result_dict.keys()):
             batch_sizes = sorted(list(set(self.result_dict[model_name]["bsz"])))
             sequence_lengths = sorted(list(set(self.result_dict[model_name]["seq_len"])))
             results = self.result_dict[model_name]["result"]
@@ -76,23 +106,33 @@
                 (batch_sizes, sequence_lengths) if self.args.plot_along_batch else (sequence_lengths, batch_sizes)
             )

+            label_model_name = (
+                model_name if self.args.short_model_names is None else self.args.short_model_names[model_name_idx]
+            )
+
             for inner_loop_value in inner_loop_array:
                 if self.args.plot_along_batch:
-                    y_axis_array = np.asarray([results[(x, inner_loop_value)] for x in x_axis_array], dtype=np.int)
+                    y_axis_array = np.asarray(
+                        [results[(x, inner_loop_value)] for x in x_axis_array if (x, inner_loop_value) in results],
+                        dtype=np.int,
+                    )
                 else:
-                    y_axis_array = np.asarray([results[(inner_loop_value, x)] for x in x_axis_array], dtype=np.float32)
+                    y_axis_array = np.asarray(
+                        [results[(inner_loop_value, x)] for x in x_axis_array if (inner_loop_value, x) in results],
+                        dtype=np.float32,
+                    )

                 (x_axis_label, inner_loop_label) = (
-                    ("batch_size", "sequence_length in #tokens")
-                    if self.args.plot_along_batch
-                    else ("sequence_length in #tokens", "batch_size")
+                    ("batch_size", "len") if self.args.plot_along_batch else ("in #tokens", "bsz")
                 )

-                x_axis_array = np.asarray(x_axis_array, np.int)
-                plt.scatter(x_axis_array, y_axis_array, label=f"{model_name} - {inner_loop_label}: {inner_loop_value}")
+                x_axis_array = np.asarray(x_axis_array, np.int)[: len(y_axis_array)]
+                plt.scatter(
+                    x_axis_array, y_axis_array, label=f"{label_model_name} - {inner_loop_label}: {inner_loop_value}"
+                )
                 plt.plot(x_axis_array, y_axis_array, "--")

-            title_str += f" {model_name} vs."
+            title_str += f" {label_model_name} vs."

         title_str = title_str[:-4]
         y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"
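For reference, a minimal standalone sketch of the value coercion that the new can_convert_to_int / can_convert_to_float checks make possible; parse_result is a hypothetical helper written only for this illustration and is not part of the commit:

def can_convert_to_int(string):
    try:
        int(string)
        return True
    except ValueError:
        return False


def can_convert_to_float(string):
    try:
        float(string)
        return True
    except ValueError:
        return False


def parse_result(raw):
    # Hypothetical helper mirroring the if/elif added to Plot.__init__ above:
    # integer cells (e.g. memory in MB) become int, float cells (e.g. time in s)
    # become float, and anything else stays out of the result dict, so the
    # plotting loop simply skips that (batch_size, sequence_length) point.
    if can_convert_to_int(raw):
        return int(raw)
    if can_convert_to_float(raw):
        return float(raw)
    return None


print(parse_result("1460"))    # 1460
print(parse_result("0.2032"))  # 0.2032
print(parse_result("N/A"))     # None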


@@ -0,0 +1,4 @@
+model,batch_size,sequence_length,result
+aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,8,512,0.2032
+aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,64,512,1.5279
+aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,256,512,6.1837
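As a usage illustration of the new short_model_names option, here is a self-contained sketch of how such a list field goes through HfArgumentParser, the parser the plotting script already uses. MiniPlotArguments is a trimmed stand-in defined only for this example, and the flag values are made up; the real PlotArguments has more fields:

from dataclasses import dataclass, field
from typing import List, Optional

from transformers import HfArgumentParser


def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)


@dataclass
class MiniPlotArguments:
    # Trimmed stand-in for PlotArguments, defined only for this illustration.
    figure_png_file: Optional[str] = field(
        default=None, metadata={"help": "Filename under which the plot will be saved."}
    )
    short_model_names: Optional[List[str]] = list_field(
        default=None, metadata={"help": "Short display names, one per model in the csv file."}
    )


parser = HfArgumentParser(MiniPlotArguments)
(plot_args,) = parser.parse_args_into_dataclasses(
    args=["--figure_png_file", "time_plot.png", "--short_model_names", "bert-cord19-squad2"]
)
print(plot_args.short_model_names)  # ['bert-cord19-squad2']

The short names are matched to models by position (self.args.short_model_names[model_name_idx] above), so the list should be given in the same order as the models appear in the csv file.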


@@ -74,12 +74,6 @@ class BenchmarkArguments:
             "help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU."
         },
     )
-    with_lm_head: bool = field(
-        default=False,
-        metadata={
-            "help": "Use model with its language model head (MODEL_WITH_LM_HEAD_MAPPING instead of MODEL_MAPPING)"
-        },
-    )
     inference_time_csv_file: str = field(
         default=f"inference_time_{round(time())}.csv",
         metadata={"help": "CSV filename used if saving time results to csv."},