[Docs] Benchmark docs (#5360)

* first doc version

* add benchmark docs

* fix typos

* improve README

* Update docs/source/benchmarks.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* fix naming and docs

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Patrick von Platen 2020-06-29 16:08:57 +02:00 committed by GitHub
parent 482c9178d3
commit 4bcc35cd69
10 changed files with 373 additions and 109 deletions

@ -1,54 +0,0 @@
# Benchmarks
This section is dedicated to the benchmarks done by the library, by maintainers, contributors and users. These
benchmarks will help keep track of the performance improvements that are brought to our models across versions.
## Benchmarking all models for inference
As of version 2.1 we have benchmarked all models for inference, across many different settings: using PyTorch, with
and without TorchScript, using TensorFlow, with and without XLA. All of those tests were done across CPUs (except for
TensorFlow XLA) and GPUs.
The approach is detailed in the [following blogpost](https://medium.com/huggingface/benchmarking-transformers-pytorch-and-tensorflow-e2917fb891c2)
The results are available [here](https://docs.google.com/spreadsheets/d/1sryqufw2D0XlUH4sq3e9Wnxu5EAQkaohzrJbd5HdQ_w/edit?usp=sharing).
## TF2 with mixed precision, XLA, Distribution (@tlkh)
This work was done by [Timothy Liu](https://github.com/tlkh).
There are very positive results to be gained from the various TensorFlow 2.0 features:
- Automatic Mixed Precision (AMP)
- XLA compiler
- Distribution strategies (multi-GPU)
The benefits are listed here (tested on CoLA, MRPC, SST-2):
- AMP: Between 1.4x and 1.6x decrease in overall time without change in batch size
- AMP+XLA: Up to 2.5x decrease in overall time on SST-2 (larger dataset)
- Distribution: Between 1.4x and 3.4x decrease in overall time on 4xV100
- Combined: Up to 5.7x decrease in overall training time, or 9.1x training throughput
The model quality (measured by the validation accuracy) fluctuates slightly. Taking an average of 4 training runs
on a single GPU gives the following results:
- CoLA: AMP results in slightly lower accuracy (0.820 vs 0.824)
- MRPC: AMP results in lower accuracy (0.823 vs 0.835)
- SST-2: AMP results in slightly lower accuracy (0.918 vs 0.922)
However, in a distributed setting with 4xV100 (4x batch size), AMP can yield better results:
- CoLA: AMP results in higher accuracy (0.828 vs 0.812)
- MRPC: AMP results in lower accuracy (0.817 vs 0.827)
- SST-2: AMP results in slightly lower accuracy (0.926 vs 0.929)
The benchmark script is available [here](https://github.com/NVAITC/benchmarking/blob/master/tf2/bert_dist.py).
Note: on some tasks (e.g. MRPC), the dataset is too small, so the overhead of compiling the model with XLA and of setting up the distribution strategy is not compensated and things do not speed up. The XLA compile time is also the reason why, although throughput
can increase a lot (e.g. 2.7x for a single GPU), the overall (end-to-end) training speed-up is not as large (as low as 1.4x).
The benefit, as seen on SST-2 (a larger dataset), is much clearer.
All results can be seen on this [Google Sheet](https://docs.google.com/spreadsheets/d/1538MN224EzjbRL239sqSiUy6YY-rAjHyXhTzz_Zptls/edit#gid=960868445).

docs/source/benchmarks.rst Normal file

@ -0,0 +1,322 @@
Benchmarks
==========
Let's take a look at how 🤗 Transformer models can be benchmarked, along with best practices and the benchmarks that are already available.
A notebook explaining in more detail how to benchmark 🤗 Transformer models can be found `here <https://github.com/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb>`__.
How to benchmark 🤗 Transformer models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The classes :class:`~transformers.PyTorchBenchmark` and :class:`~transformers.TensorFlowBenchmark` make it possible to flexibly benchmark 🤗 Transformer models.
The benchmark classes allow us to measure the `peak memory usage` and `required time` for both
`inference` and `training`.
.. note::
Here, `inference` is defined as a single forward pass, and `training` as a single forward pass followed by a backward pass.
The benchmark classes :class:`~transformers.PyTorchBenchmark` and :class:`~transformers.TensorFlowBenchmark` expect an object of type :class:`~transformers.PyTorchBenchmarkArguments` and :class:`~transformers.TensorFlowBenchmarkArguments`, respectively, for instantiation. :class:`~transformers.PyTorchBenchmarkArguments` and :class:`~transformers.TensorFlowBenchmarkArguments` are data classes and contain all relevant configurations for their corresponding benchmark class.
The following example shows how a BERT model of type `bert-base-uncased` can be benchmarked.
.. code-block::
>>> ## PYTORCH CODE
>>> from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments
>>> args = PyTorchBenchmarkArguments(models=["bert-base-uncased"], batch_sizes=[8], sequence_lengths=[8, 32, 128, 512])
>>> benchmark = PyTorchBenchmark(args)
>>> ## TENSORFLOW CODE
>>> from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments
>>> args = TensorFlowBenchmarkArguments(models=["bert-base-uncased"], batch_sizes=[8], sequence_lengths=[8, 32, 128, 512])
>>> benchmark = TensorFlowBenchmark(args)
Here, three arguments are given to the benchmark argument data classes, namely ``models``, ``batch_sizes``, and ``sequence_lengths``. The argument ``models`` is required and expects a :obj:`list` of model identifiers from the `model hub <https://huggingface.co/models>`__.
The :obj:`list` arguments ``batch_sizes`` and ``sequence_lengths`` define the size of the ``input_ids`` on which the model is benchmarked.
There are many more parameters that can be configured via the benchmark argument data classes. For more detail on these, one can either directly consult the files
``src/transformers/benchmark/benchmark_args_utils.py``, ``src/transformers/benchmark/benchmark_args.py`` (for PyTorch) and ``src/transformers/benchmark/benchmark_args_tf.py`` (for TensorFlow).
Alternatively, running the following shell commands from the repository root will print out a descriptive list of all configurable parameters for PyTorch and TensorFlow, respectively.
.. code-block::
>>> ## PYTORCH CODE
python examples/benchmarking/run_benchmark.py --help
>>> ## TENSORFLOW CODE
python examples/benchmarking/run_benchmark_tf.py --help
An instantiated benchmark object can then simply be run by calling ``benchmark.run()``.
.. code-block::
>>> ## PYTORCH CODE
>>> results = benchmark.run()
>>> print(results)
==================== INFERENCE - SPEED - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Time in s
--------------------------------------------------------------------------------
bert-base-uncased 8 8 0.006
bert-base-uncased 8 32 0.006
bert-base-uncased 8 128 0.018
bert-base-uncased 8 512 0.088
--------------------------------------------------------------------------------
==================== INFERENCE - MEMORY - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Memory in MB
--------------------------------------------------------------------------------
bert-base-uncased 8 8 1227
bert-base-uncased 8 32 1281
bert-base-uncased 8 128 1307
bert-base-uncased 8 512 1539
--------------------------------------------------------------------------------
==================== ENVIRONMENT INFORMATION ====================
- transformers_version: 2.11.0
- framework: PyTorch
- use_torchscript: False
- framework_version: 1.4.0
- python_version: 3.6.10
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-06-29
- time: 08:58:43.371351
- fp16: False
- use_multiprocessing: True
- only_pretrain_model: False
- cpu_ram_mb: 32088
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 2
- use_tpu: False
>>> ## TENSORFLOW CODE
>>> results = benchmark.run()
>>> print(results)
==================== INFERENCE - SPEED - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Time in s
--------------------------------------------------------------------------------
bert-base-uncased 8 8 0.005
bert-base-uncased 8 32 0.008
bert-base-uncased 8 128 0.022
bert-base-uncased 8 512 0.105
--------------------------------------------------------------------------------
==================== INFERENCE - MEMORY - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Memory in MB
--------------------------------------------------------------------------------
bert-base-uncased 8 8 1330
bert-base-uncased 8 32 1330
bert-base-uncased 8 128 1330
bert-base-uncased 8 512 1770
--------------------------------------------------------------------------------
==================== ENVIRONMENT INFORMATION ====================
- transformers_version: 2.11.0
- framework: Tensorflow
- use_xla: False
- framework_version: 2.2.0
- python_version: 3.6.10
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-06-29
- time: 09:26:35.617317
- fp16: False
- use_multiprocessing: True
- only_pretrain_model: False
- cpu_ram_mb: 32088
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 2
- use_tpu: False
By default, the `time` and the `required memory` for `inference` are benchmarked.
In the example output above, the first two sections show the results corresponding to `inference time` and `inference memory`.
In addition, all relevant information about the computing environment, `e.g.` the GPU type, the system, the library versions, etc., is printed out in the third section under `ENVIRONMENT INFORMATION`.
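Training can be benchmarked as well via the ``training`` flag of the argument data classes, in which case the time and memory for a forward plus backward pass are additionally reported; note that, according to the accompanying notebook, training benchmarks are not yet supported for TensorFlow. A minimal sketch, assuming only that flag:
.. code-block::
>>> ## PYTORCH CODE
>>> from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments
>>> # sketch: `training=True` additionally benchmarks a forward + backward pass
>>> args = PyTorchBenchmarkArguments(models=["bert-base-uncased"], training=True, batch_sizes=[8], sequence_lengths=[8, 32, 128])
>>> benchmark = PyTorchBenchmark(args)
>>> results = benchmark.run()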
All of this information can optionally be saved in a `.csv` file by adding the argument :obj:`save_to_csv=True` to :class:`~transformers.PyTorchBenchmarkArguments` and :class:`~transformers.TensorFlowBenchmarkArguments`, respectively.
In this case, every section is saved in a separate `.csv` file. The path to each `.csv` file can optionally be defined via the argument data classes.
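For example, the following sketch writes the inference results to user-defined `.csv` files. The path argument names used here (``inference_time_csv_file``, ``inference_memory_csv_file``, ``env_info_csv_file``) are assumed to mirror the fields in ``src/transformers/benchmark/benchmark_args_utils.py`` at the time of writing; double-check the exact names via the ``--help`` output of the example scripts.
.. code-block::
>>> ## PYTORCH CODE
>>> from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments
>>> args = PyTorchBenchmarkArguments(
...     models=["bert-base-uncased"],
...     batch_sizes=[8],
...     sequence_lengths=[8, 32, 128, 512],
...     save_to_csv=True,
...     inference_time_csv_file="inference_time.csv",
...     inference_memory_csv_file="inference_memory.csv",
...     env_info_csv_file="env_info.csv",
... )
>>> PyTorchBenchmark(args).run()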
Instead of benchmarking pre-trained models via their model identifier, `e.g.` `bert-base-uncased`, the user can alternatively benchmark an arbitrary configuration of any available model class.
In this case, a :obj:`list` of configurations must be passed to the benchmark class together with the benchmark args, as follows.
.. code-block::
>>> ## PYTORCH CODE
>>> from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments, BertConfig
>>> args = PyTorchBenchmarkArguments(models=["bert-base", "bert-384-hid", "bert-6-lay"], batch_sizes=[8], sequence_lengths=[8, 32, 128, 512])
>>> config_base = BertConfig()
>>> config_384_hid = BertConfig(hidden_size=384)
>>> config_6_lay = BertConfig(num_hidden_layers=6)
>>> benchmark = PyTorchBenchmark(args, configs=[config_base, config_384_hid, config_6_lay])
>>> benchmark.run()
==================== INFERENCE - SPEED - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Time in s
--------------------------------------------------------------------------------
bert-base 8 8 0.006
bert-base 8 32 0.006
bert-base 8 128 0.018
bert-base 8 512 0.088
bert-384-hid 8 8 0.006
bert-384-hid 8 32 0.006
bert-384-hid 8 128 0.011
bert-384-hid 8 512 0.054
bert-6-lay 8 8 0.003
bert-6-lay 8 32 0.004
bert-6-lay 8 128 0.009
bert-6-lay 8 512 0.044
--------------------------------------------------------------------------------
==================== INFERENCE - MEMORY - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Memory in MB
--------------------------------------------------------------------------------
bert-base 8 8 1277
bert-base 8 32 1281
bert-base 8 128 1307
bert-base 8 512 1539
bert-384-hid 8 8 1005
bert-384-hid 8 32 1027
bert-384-hid 8 128 1035
bert-384-hid 8 512 1255
bert-6-lay 8 8 1097
bert-6-lay 8 32 1101
bert-6-lay 8 128 1127
bert-6-lay 8 512 1359
--------------------------------------------------------------------------------
==================== ENVIRONMENT INFORMATION ====================
- transformers_version: 2.11.0
- framework: PyTorch
- use_torchscript: False
- framework_version: 1.4.0
- python_version: 3.6.10
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-06-29
- time: 09:35:25.143267
- fp16: False
- use_multiprocessing: True
- only_pretrain_model: False
- cpu_ram_mb: 32088
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 2
- use_tpu: False
>>> ## TENSORFLOW CODE
>>> from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments, BertConfig
>>> args = TensorFlowBenchmarkArguments(models=["bert-base", "bert-384-hid", "bert-6-lay"], batch_sizes=[8], sequence_lengths=[8, 32, 128, 512])
>>> config_base = BertConfig()
>>> config_384_hid = BertConfig(hidden_size=384)
>>> config_6_lay = BertConfig(num_hidden_layers=6)
>>> benchmark = TensorFlowBenchmark(args, configs=[config_base, config_384_hid, config_6_lay])
>>> benchmark.run()
==================== INFERENCE - SPEED - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Time in s
--------------------------------------------------------------------------------
bert-base 8 8 0.005
bert-base 8 32 0.008
bert-base 8 128 0.022
bert-base 8 512 0.106
bert-384-hid 8 8 0.005
bert-384-hid 8 32 0.007
bert-384-hid 8 128 0.018
bert-384-hid 8 512 0.064
bert-6-lay 8 8 0.002
bert-6-lay 8 32 0.003
bert-6-lay 8 128 0.0011
bert-6-lay 8 512 0.074
--------------------------------------------------------------------------------
==================== INFERENCE - MEMORY - RESULT ====================
--------------------------------------------------------------------------------
Model Name Batch Size Seq Length Memory in MB
--------------------------------------------------------------------------------
bert-base 8 8 1330
bert-base 8 32 1330
bert-base 8 128 1330
bert-base 8 512 1770
bert-384-hid 8 8 1330
bert-384-hid 8 32 1330
bert-384-hid 8 128 1330
bert-384-hid 8 512 1540
bert-6-lay 8 8 1330
bert-6-lay 8 32 1330
bert-6-lay 8 128 1330
bert-6-lay 8 512 1540
--------------------------------------------------------------------------------
==================== ENVIRONMENT INFORMATION ====================
- transformers_version: 2.11.0
- framework: Tensorflow
- use_xla: False
- framework_version: 2.2.0
- python_version: 3.6.10
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-06-29
- time: 09:38:15.487125
- fp16: False
- use_multiprocessing: True
- only_pretrain_model: False
- cpu_ram_mb: 32088
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 2
- use_tpu: False
Again, `inference time` and `required memory` for `inference` are measured, but this time for customized configurations of the :obj:`BertModel` class. This feature can be especially helpful when
deciding which configuration the model should be trained with.
Benchmark best practices
~~~~~~~~~~~~~~~~~~~~~~~~
This section lists a couple of best practices one should be aware of when benchmarking a model.
- Currently, only single device benchmarking is supported. When benchmarking on GPU, it is recommended that the user
specifies on which device the code should be run by setting the ``CUDA_VISIBLE_DEVICES`` environment variable in the shell, `e.g.` ``export CUDA_VISIBLE_DEVICES=0``, before running the code (see the sketch after this list).
- The option :obj:`no_multi_processing` should only be set to :obj:`True` for testing and debugging. To ensure accurate memory measurement, each memory benchmark should be run in its own process, which is the case as long as :obj:`no_multi_processing` is left at its default value of :obj:`False`.
- One should always state the environment information when sharing the results of a model benchmark. Results can vary heavily between different GPU devices, library versions, etc., so benchmark results on their own are not very useful for the community.
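For instance, a GPU benchmark pinned to a single device could be launched as follows. This is only a sketch: the command-line flags simply mirror the fields of the argument data classes and can be double-checked with ``--help``.
.. code-block::
# pin the benchmark to the first GPU, then run the PyTorch example script
export CUDA_VISIBLE_DEVICES=0
python examples/benchmarking/run_benchmark.py --models bert-base-uncased --batch_sizes 8 --sequence_lengths 8 32 128 512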
Sharing your benchmark
~~~~~~~~~~~~~~~~~~~~~~
Previously all available core models (10 at the time) were benchmarked for `inference time`, across many different settings: using PyTorch, with
and without TorchScript, using TensorFlow, with and without XLA. All of those tests were done across CPUs (except for
TensorFlow XLA) and GPUs.
The approach is detailed in the `following blogpost <https://medium.com/huggingface/benchmarking-transformers-pytorch-and-tensorflow-e2917fb891c2>`__ and the results are available `here <https://docs.google.com/spreadsheets/d/1sryqufw2D0XlUH4sq3e9Wnxu5EAQkaohzrJbd5HdQ_w/edit?usp=sharing>`__.
With the new `benchmark` tools, it is easier than ever to share your benchmark results with the community `here <https://github.com/huggingface/transformers/blob/master/examples/benchmarking/README.md>`__.

@ -13,15 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training in Tensorflow"""
""" Benchmarking the library on inference and training in TensorFlow"""
from transformers import HfArgumentParser, TensorflowBenchmark, TensorflowBenchmarkArguments
from transformers import HfArgumentParser, TensorFlowBenchmark, TensorFlowBenchmarkArguments
def main():
parser = HfArgumentParser(TensorflowBenchmarkArguments)
parser = HfArgumentParser(TensorFlowBenchmarkArguments)
benchmark_args = parser.parse_args_into_dataclasses()[0]
benchmark = TensorflowBenchmark(args=benchmark_args)
benchmark = TensorFlowBenchmark(args=benchmark_args)
benchmark.run()

@ -1,4 +0,0 @@
model,batch_size,sequence_length,result
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,8,512,0.2032
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,64,512,1.5279
aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2,256,512,6.1837

@ -289,7 +289,7 @@
"\n",
"Being able to accurately benchmark language models on both *speed* and *required memory* is therefore very important.\n",
"\n",
"HuggingFace's Transformer library allows users to benchmark models for both Tensorflow 2 and PyTorch using the `PyTorchBenchmark` and `TensorflowBenchmark` classes.\n",
"HuggingFace's Transformer library allows users to benchmark models for both TensorFlow 2 and PyTorch using the `PyTorchBenchmark` and `TensorFlowBenchmark` classes.\n",
"\n",
"The currently available features for `PyTorchBenchmark` are summarized in the following table.\n",
"\n",
@ -306,7 +306,7 @@
"\n",
"* *torchscript* corresponds to PyTorch's torchscript format, see [here](https://pytorch.org/docs/stable/jit.html).\n",
"\n",
"The currently available features for `TensorflowBenchmark` are summarized in the following table.\n",
"The currently available features for `TensorFlowBenchmark` are summarized in the following table.\n",
"\n",
"| | CPU | CPU + eager execution | GPU | GPU + eager execution | GPU + XLA | GPU + FP16 | TPU |\n",
":-- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n",
@ -315,16 +315,16 @@
"**Speed - Train** | ✘ | ✘ | ✘ | ✘ | ✘ | ✘ | ✘ |\n",
"**Memory - Train** | ✘ | ✘ | ✘ | ✘ | ✘ | ✘ | ✘ |\n",
"\n",
"* *eager execution* means that the function is run in the eager execution environment of Tensorflow 2, see [here](https://www.tensorflow.org/guide/eager).\n",
"* *eager execution* means that the function is run in the eager execution environment of TensorFlow 2, see [here](https://www.tensorflow.org/guide/eager).\n",
"\n",
"* *XLA* stands for Tensorflow's Accelerated Linear Algebra (XLA) compiler, see [here](https://www.tensorflow.org/xla)\n",
"* *XLA* stands for TensorFlow's Accelerated Linear Algebra (XLA) compiler, see [here](https://www.tensorflow.org/xla)\n",
"\n",
"* *FP16* stands for Tensorflow's mixed-precision package and is analogous to PyTorch's FP16 feature, see [here](https://www.tensorflow.org/guide/mixed_precision).\n",
"* *FP16* stands for TensorFlow's mixed-precision package and is analogous to PyTorch's FP16 feature, see [here](https://www.tensorflow.org/guide/mixed_precision).\n",
"\n",
"***Note***: In ~1,2 weeks it will also be possible to benchmark training in Tensorflow.\n",
"***Note***: In ~1,2 weeks it will also be possible to benchmark training in TensorFlow.\n",
"\n",
"\n",
"This notebook will show the user how to use `PyTorchBenchmark` and `TensorflowBenchmark` for two different scenarios:\n",
"This notebook will show the user how to use `PyTorchBenchmark` and `TensorFlowBenchmark` for two different scenarios:\n",
"\n",
"1. **Inference - Pre-trained Model Comparison** - *A user wants to implement a pre-trained model in production for inference. She wants to compare different models on speed and required memory.*\n",
"\n",
@ -443,7 +443,7 @@
"source": [
"Looks good! Now we import `transformers` and download the scripts `run_benchmark.py`, `run_benchmark_tf.py`, and `plot_csv_file.py` which can be found under `transformers/examples/benchmarking`.\n",
"\n",
"`run_benchmark_tf.py` and `run_benchmark.py` are very simple scripts leveraging the `PyTorchBenchmark` and `TensorflowBenchmark` classes, respectively."
"`run_benchmark_tf.py` and `run_benchmark.py` are very simple scripts leveraging the `PyTorchBenchmark` and `TensorFlowBenchmark` classes, respectively."
]
},
{
@ -482,7 +482,7 @@
"colab_type": "text"
},
"source": [
"Information about the input arguments to the *run_benchmark* scripts can be accessed by running `!python run_benchmark.py --help` for PyTorch and `!python run_benchmark_tf.py --help` for Tensorflow."
"Information about the input arguments to the *run_benchmark* scripts can be accessed by running `!python run_benchmark.py --help` for PyTorch and `!python run_benchmark_tf.py --help` for TensorFlow."
]
},
{
@ -1130,7 +1130,7 @@
},
"source": [
"At this point, it is important to understand how the peak memory is measured. The benchmarking tools measure the peak memory usage the same way the command `nvidia-smi` does - see [here](https://developer.nvidia.com/nvidia-system-management-interface) for more information. \n",
"In short, all memory that is allocated for a given *model identifier*, *batch size* and *sequence length* is measured in a separate process. This way it can be ensured that there is no previously unreleased memory falsely included in the measurement. One should also note that the measured memory even includes the memory allocated by the CUDA driver to load PyTorch and Tensorflow and is, therefore, higher than library-specific memory measurement function, *e.g.* this one for [PyTorch](https://pytorch.org/docs/stable/cuda.html#torch.cuda.max_memory_allocated).\n",
"In short, all memory that is allocated for a given *model identifier*, *batch size* and *sequence length* is measured in a separate process. This way it can be ensured that there is no previously unreleased memory falsely included in the measurement. One should also note that the measured memory even includes the memory allocated by the CUDA driver to load PyTorch and TensorFlow and is, therefore, higher than library-specific memory measurement function, *e.g.* this one for [PyTorch](https://pytorch.org/docs/stable/cuda.html#torch.cuda.max_memory_allocated).\n",
"\n",
"Alright, let's analyze the results. It can be noted that the models `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` and `deepset/roberta-base-squad2` require significantly less memory than the other three models. Besides `mrm8488/longformer-base-4096-finetuned-squadv2` all models more or less follow the same memory consumption pattern with `aodiniz/bert_uncased_L-10_H-512_A-8_cord19-200616_squad2` seemingly being able to better scale to larger sequence lengths. \n",
"`mrm8488/longformer-base-4096-finetuned-squadv2` is a *Longformer* model, which makes use of *LocalAttention* (check this blog post to learn more about local attention) so that the model scales much better to longer input sequences.\n",
@ -1256,7 +1256,7 @@
"source": [
"Interesting! `aodiniz/bert_uncased_L-10_H-51` clearly scales better for higher batch sizes and does not even run out of memory for 512 tokens.\n",
"\n",
"For comparison, let's run the same benchmarking on Tensorflow."
"For comparison, let's run the same benchmarking on TensorFlow."
]
},
{
@ -1341,7 +1341,7 @@
"colab_type": "text"
},
"source": [
"Let's see the same plot for Tensorflow."
"Let's see the same plot for TensorFlow."
]
},
{
@ -1394,7 +1394,7 @@
"colab_type": "text"
},
"source": [
"The model implemented in Tensorflow requires more memory than the one implemented in PyTorch. Let's say for whatever reason we have decided to use Tensorflow instead of PyTorch. \n",
"The model implemented in TensorFlow requires more memory than the one implemented in PyTorch. Let's say for whatever reason we have decided to use TensorFlow instead of PyTorch. \n",
"\n",
"The next step is to measure the inference time of these two models. Instead of disabling time measurement with `--no_speed`, we will now disable memory measurement with `--no_memory`."
]
@ -1499,7 +1499,7 @@
"source": [
"Ok, this took some time... time measurements take much longer than memory measurements because the forward pass is called multiple times for stable results. Timing measurements leverage Python's [timeit module](https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat) and run 10 times the value given to the `--repeat` argument (defaults to 3), so in our case 30 times.\n",
"\n",
"Let's focus on the resulting plot. It becomes obvious that `aodiniz/bert_uncased_L-10_H-51` is around twice as fast as `deepset/roberta-base-squad2`. Given that the model is also more memory efficient and assuming that the model performs reasonably well, for the sake of this notebook we will settle on `aodiniz/bert_uncased_L-10_H-51`. Our model should be able to process input sequences of up to 512 tokens. Latency time of around 2 seconds might be too long though, so let's compare the time for different batch sizes and using Tensorflows XLA package for more speed."
"Let's focus on the resulting plot. It becomes obvious that `aodiniz/bert_uncased_L-10_H-51` is around twice as fast as `deepset/roberta-base-squad2`. Given that the model is also more memory efficient and assuming that the model performs reasonably well, for the sake of this notebook we will settle on `aodiniz/bert_uncased_L-10_H-51`. Our model should be able to process input sequences of up to 512 tokens. Latency time of around 2 seconds might be too long though, so let's compare the time for different batch sizes and using TensorFlows XLA package for more speed."
]
},
{
@ -1551,7 +1551,7 @@
"colab_type": "text"
},
"source": [
"First of all, it can be noted that XLA reduces latency time by a factor of ca. 1.3 (which is more than observed for other models by Tensorflow [here](https://www.tensorflow.org/xla)). A batch size of 64 looks like a good choice. More or less half a second for the forward pass is good enough.\n",
"First of all, it can be noted that XLA reduces latency time by a factor of ca. 1.3 (which is more than observed for other models by TensorFlow [here](https://www.tensorflow.org/xla)). A batch size of 64 looks like a good choice. More or less half a second for the forward pass is good enough.\n",
"\n",
"Cool, now it should be straightforward to benchmark your favorite models. All the inference time measurements can also be done using the `run_benchmark.py` script for PyTorch."
]
@ -2021,4 +2021,4 @@
]
}
]
}
}

@ -613,8 +613,8 @@ if is_tf_available():
from .trainer_tf import TFTrainer
# Benchmarks
from .benchmark.benchmark_tf import TensorflowBenchmark
from .benchmark.benchmark_args_tf import TensorflowBenchmarkArguments
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments
if not is_tf_available() and not is_torch_available():

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
@dataclass
class TensorflowBenchmarkArguments(BenchmarkArguments):
class TensorFlowBenchmarkArguments(BenchmarkArguments):
tpu_name: str = field(
default=None, metadata={"help": "Name of TPU"},
)

@ -38,7 +38,7 @@ from .benchmark_utils import (
if is_tf_available():
import tensorflow as tf
from .benchmark_args_tf import TensorflowBenchmarkArguments
from .benchmark_args_tf import TensorFlowBenchmarkArguments
from tensorflow.python.framework.errors_impl import ResourceExhaustedError
if is_py3nvml_available():
@ -75,11 +75,11 @@ def random_input_ids(batch_size: int, sequence_length: int, vocab_size: int) ->
return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
class TensorflowBenchmark(Benchmark):
class TensorFlowBenchmark(Benchmark):
args: TensorflowBenchmarkArguments
args: TensorFlowBenchmarkArguments
configs: PretrainedConfig
framework: str = "Tensorflow"
framework: str = "TensorFlow"
@property
def framework_version(self):
@ -88,7 +88,7 @@ class TensorflowBenchmark(Benchmark):
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
# initialize GPU on separate process
strategy = self.args.strategy
assert strategy is not None, "A device strategy has to be initialized before using Tensorflow."
assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_speed(_inference)
@ -104,7 +104,7 @@ class TensorflowBenchmark(Benchmark):
if self.args.is_gpu:
tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
strategy = self.args.strategy
assert strategy is not None, "A device strategy has to be initialized before using Tensorflow."
assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
return self._measure_memory(_inference)
@ -166,7 +166,7 @@ class TensorflowBenchmark(Benchmark):
def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
logger.info(
"Note that Tensorflow allocates more memory than"
"Note that TensorFlow allocates more memory than"
"it might need to speed up computation."
"The memory reported here corresponds to the memory"
"reported by `nvidia-smi`, which can vary depending"
@ -210,7 +210,7 @@ class TensorflowBenchmark(Benchmark):
# cpu
if self.args.trace_memory_line_by_line:
logger.info(
"When enabling line by line tracing, the max peak memory for CPU is inaccurate in Tensorflow."
"When enabling line by line tracing, the max peak memory for CPU is inaccurate in TensorFlow."
)
memory = None
else:

@ -740,7 +740,7 @@ class Benchmark(ABC):
info["framework"] = self.framework
if self.framework == "PyTorch":
info["use_torchscript"] = self.args.torchscript
if self.framework == "Tensorflow":
if self.framework == "TensorFlow":
info["eager_mode"] = self.args.eager_mode
info["use_xla"] = self.args.use_xla
info["framework_version"] = self.framework_version

@ -10,7 +10,7 @@ from .utils import require_tf
if is_tf_available():
import tensorflow as tf
from transformers import TensorflowBenchmark, TensorflowBenchmarkArguments
from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments
@require_tf
@ -23,7 +23,7 @@ class TFBenchmarkTest(unittest.TestCase):
def test_inference_no_configs_eager(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -32,14 +32,14 @@ class TFBenchmarkTest(unittest.TestCase):
eager_mode=True,
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args)
benchmark = TensorFlowBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_no_configs_only_pretrain(self):
MODEL_ID = "sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -48,14 +48,14 @@ class TFBenchmarkTest(unittest.TestCase):
no_multi_process=True,
only_pretrain_model=True,
)
benchmark = TensorflowBenchmark(benchmark_args)
benchmark = TensorFlowBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
def test_inference_no_configs_graph(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -63,7 +63,7 @@ class TFBenchmarkTest(unittest.TestCase):
batch_sizes=[1],
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args)
benchmark = TensorFlowBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
@ -71,7 +71,7 @@ class TFBenchmarkTest(unittest.TestCase):
def test_inference_with_configs_eager(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = AutoConfig.from_pretrained(MODEL_ID)
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -80,7 +80,7 @@ class TFBenchmarkTest(unittest.TestCase):
eager_mode=True,
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args, [config])
benchmark = TensorFlowBenchmark(benchmark_args, [config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
@ -88,7 +88,7 @@ class TFBenchmarkTest(unittest.TestCase):
def test_inference_with_configs_graph(self):
MODEL_ID = "sshleifer/tiny-gpt2"
config = AutoConfig.from_pretrained(MODEL_ID)
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -96,7 +96,7 @@ class TFBenchmarkTest(unittest.TestCase):
batch_sizes=[1],
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args, [config])
benchmark = TensorFlowBenchmark(benchmark_args, [config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
@ -104,7 +104,7 @@ class TFBenchmarkTest(unittest.TestCase):
def test_inference_encoder_decoder_with_configs(self):
MODEL_ID = "patrickvonplaten/t5-tiny-random"
config = AutoConfig.from_pretrained(MODEL_ID)
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -112,7 +112,7 @@ class TFBenchmarkTest(unittest.TestCase):
batch_sizes=[1],
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args, configs=[config])
benchmark = TensorFlowBenchmark(benchmark_args, configs=[config])
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
@ -120,7 +120,7 @@ class TFBenchmarkTest(unittest.TestCase):
@unittest.skipIf(is_tf_available() and len(tf.config.list_physical_devices("GPU")) == 0, "Cannot do xla on CPU.")
def test_inference_no_configs_xla(self):
MODEL_ID = "sshleifer/tiny-gpt2"
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
training=False,
no_inference=False,
@ -129,7 +129,7 @@ class TFBenchmarkTest(unittest.TestCase):
use_xla=True,
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args)
benchmark = TensorFlowBenchmark(benchmark_args)
results = benchmark.run()
self.check_results_dict_not_empty(results.time_inference_result)
self.check_results_dict_not_empty(results.memory_inference_result)
@ -137,7 +137,7 @@ class TFBenchmarkTest(unittest.TestCase):
def test_save_csv_files(self):
MODEL_ID = "sshleifer/tiny-gpt2"
with tempfile.TemporaryDirectory() as tmp_dir:
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
no_inference=False,
save_to_csv=True,
@ -148,7 +148,7 @@ class TFBenchmarkTest(unittest.TestCase):
env_info_csv_file=os.path.join(tmp_dir, "env.csv"),
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args)
benchmark = TensorFlowBenchmark(benchmark_args)
benchmark.run()
self.assertTrue(Path(os.path.join(tmp_dir, "inf_time.csv")).exists())
self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists())
@ -164,7 +164,7 @@ class TFBenchmarkTest(unittest.TestCase):
self.assertTrue(hasattr(summary, "total"))
with tempfile.TemporaryDirectory() as tmp_dir:
benchmark_args = TensorflowBenchmarkArguments(
benchmark_args = TensorFlowBenchmarkArguments(
models=[MODEL_ID],
no_inference=False,
sequence_lengths=[8],
@ -175,7 +175,7 @@ class TFBenchmarkTest(unittest.TestCase):
eager_mode=True,
no_multi_process=True,
)
benchmark = TensorflowBenchmark(benchmark_args)
benchmark = TensorFlowBenchmark(benchmark_args)
result = benchmark.run()
_check_summary_is_not_empty(result.inference_summary)
self.assertTrue(Path(os.path.join(tmp_dir, "log.txt")).exists())