<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Exporting 🤗 Transformers models

## ONNX / ONNXRuntime

Projects [ONNX (Open Neural Network eXchange)](http://onnx.ai) and [ONNXRuntime (ORT)](https://microsoft.github.io/onnxruntime/) are part of an effort from leading industry players in the AI field to provide a unified and community-driven format to store and, by extension, efficiently execute neural networks, leveraging a variety of hardware and dedicated optimizations.

Starting from transformers v2.10.0 we partnered with ONNX Runtime to provide an easy export of transformers models to the ONNX format. You can have a look at this effort in our joint blog post [Accelerate your NLP pipelines using Hugging Face Transformers and ONNX Runtime](https://medium.com/microsoftazure/accelerate-your-nlp-pipelines-using-hugging-face-transformers-and-onnx-runtime-2443578f4333).

### Configuration-based approach

Transformers v4.9.0 introduces a new package: `transformers.onnx`. This package allows converting checkpoints to an ONNX graph by leveraging configuration objects. These configuration objects come ready-made for a number of model architectures, and are designed to be easily extendable to other architectures.

Ready-made configurations include the following models:

<!--This table is automatically generated by make style, do not fill manually!-->

- ALBERT
- BART
- BERT
- CamemBERT
- DistilBERT
- GPT Neo
- LayoutLM
- Longformer
- Marian
- mBART
- OpenAI GPT-2
- RoBERTa
- T5
- XLM-RoBERTa

This conversion is handled with the PyTorch version of the models; it therefore requires PyTorch to be installed. If you would like to be able to convert from TensorFlow, please let us know by opening an issue.

<Tip>

The models showcased here are close to being fully feature complete, but do lack some features that are currently in development. Namely, the ability to handle the past key values for decoder models is currently in the works.

</Tip>

#### Converting a model to ONNX using the `transformers.onnx` package

The package may be used as a Python module:

```bash
python -m transformers.onnx --help

usage: Hugging Face ONNX Exporter tool [-h] -m MODEL -f {pytorch} [--features {default}] [--opset OPSET] [--atol ATOL] output

positional arguments:
  output                Path indicating where to store the generated ONNX model.

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL, --model MODEL
                        Model's name or path on disk to load.
  --features {default}  Export the model with some additional features.
  --opset OPSET         ONNX opset version to export the model with (default 12).
  --atol ATOL           Absolute difference tolerance when validating the model.
```

Exporting a checkpoint using a ready-made configuration can be done as follows:

```bash
python -m transformers.onnx --model=bert-base-cased onnx/bert-base-cased/
```

This exports an ONNX graph of the mentioned checkpoint. Here it is *bert-base-cased*, but it can be any checkpoint from the Hugging Face Hub or a local path.

The model is exported under `onnx/bert-base-cased/`. You should see logs similar to the following:

```bash
Validating ONNX model...
        -[✓] ONNX model outputs' name match reference model ({'pooler_output', 'last_hidden_state'})
        - Validating ONNX Model output "last_hidden_state":
                -[✓] (2, 8, 768) matches (2, 8, 768)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "pooler_output":
                -[✓] (2, 768) matches (2, 768)
                -[✓] all values close (atol: 0.0001)
All good, model saved at: onnx/bert-base-cased/model.onnx
```

This export can now be used for inference with ONNX Runtime:

```python
import onnxruntime as ort

from transformers import BertTokenizerFast

# Tokenizer matching the exported checkpoint; NumPy tensors feed directly into ONNX Runtime
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# Load the exported graph into an ONNX Runtime inference session
ort_session = ort.InferenceSession("onnx/bert-base-cased/model.onnx")

inputs = tokenizer("Using BERT in ONNX!", return_tensors="np")
outputs = ort_session.run(["last_hidden_state", "pooler_output"], dict(inputs))
```

The outputs used (`["last_hidden_state", "pooler_output"]`) can be obtained by taking a look at the ONNX configuration of each model. For example, for BERT:

```python
from transformers.models.bert import BertConfig, BertOnnxConfig

config = BertConfig()
onnx_config = BertOnnxConfig(config)

# ["last_hidden_state", "pooler_output"] for BertModel
output_keys = list(onnx_config.outputs.keys())
```
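
The same export can also be driven from Python instead of the CLI. The sketch below assumes the `export` and `validate_model_outputs` helpers exposed by the `transformers.onnx` package (their exact names and signatures may differ between versions); treat it as an illustration of the flow rather than a definitive API reference:

```python
# A minimal sketch of exporting programmatically instead of via the CLI.
# Assumes the `export` and `validate_model_outputs` helpers of `transformers.onnx`;
# their exact signatures may vary across versions.
from pathlib import Path

from transformers import BertModel, BertTokenizerFast
from transformers.models.bert import BertOnnxConfig
from transformers.onnx import export, validate_model_outputs

model_name = "bert-base-cased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Build the ONNX configuration from the model's configuration
onnx_config = BertOnnxConfig(model.config)

# Export the graph (opset 12 mirrors the CLI default) and validate it against the reference model
onnx_path = Path("onnx/bert-base-cased/model.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, opset=12, output=onnx_path)
validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, atol=1e-4)
```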

#### Implementing a custom configuration for an unsupported architecture

Let's take a look at the changes necessary to add a custom configuration for an unsupported architecture. First, we need a custom ONNX configuration object that details the model inputs and outputs. The BERT ONNX configuration is shown below:

```python
class BertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
```

Let's understand what's happening here. This configuration has two properties: the inputs and the outputs.

The `inputs` property returns a dictionary where each key corresponds to an expected input, and each value indicates the dynamic axes of that input.

For BERT, there are three necessary inputs. These three inputs have a similar shape, made up of two dimensions: the batch is the first dimension and the sequence is the second.

The `outputs` property returns a similar dictionary where, once again, each key corresponds to an expected output, and each value indicates the axis of that output.
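
To make the role of these axis mappings concrete, here is a rough illustration of how such a configuration translates into the `dynamic_axes` argument of `torch.onnx.export`, which the exporter relies on under the hood. This is a sketch only, not the exporter's actual implementation:

```python
# Illustration only: how the axis mappings declared by an OnnxConfig translate into
# the dynamic_axes argument of torch.onnx.export (not the exporter's actual code).
import torch

from transformers import BertModel, BertTokenizerFast
from transformers.models.bert import BertOnnxConfig

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = BertModel.from_pretrained("bert-base-cased")
onnx_config = BertOnnxConfig(model.config)

# A dict of input tensors; a trailing dict in the args tuple is treated as keyword arguments
dummy_inputs = dict(tokenizer("Dummy input", return_tensors="pt"))

torch.onnx.export(
    model,
    (dummy_inputs,),
    "model.onnx",
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    # e.g. {"input_ids": {0: "batch", 1: "sequence"}, ..., "pooler_output": {0: "batch"}}
    dynamic_axes={**onnx_config.inputs, **onnx_config.outputs},
    opset_version=12,
)
```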

Once this is done, a single step remains: adding this configuration object to the initialisation of the model class, and to the general `transformers` initialisation.

An important detail is the use of *OrderedDict* in both the inputs and outputs properties. This is a requirement, as inputs are matched against their relative position within the *PreTrainedModel.forward()* prototype and outputs are matched against their position in the returned *BaseModelOutputX* instance.

An example of such an addition is visible here, for the MBart model: [Making MBART ONNX-convertible](https://github.com/huggingface/transformers/pull/13049/commits/d097adcebd89a520f04352eb215a85916934204f)

If you would like to contribute your addition to the library, we recommend you implement tests. An example of such tests is visible here: [Adding tests to the MBART ONNX conversion](https://github.com/huggingface/transformers/pull/13049/commits/5d642f65abf45ceeb72bd855ca7bfe2506a58e6a)

### Graph conversion

<Tip>

The approach detailed here is being deprecated. We recommend you follow the section above for an up-to-date approach.

</Tip>

Exporting a model is done through the script *convert_graph_to_onnx.py* at the root of the transformers sources. To export a BERT model from the library, simply run:

```bash
python convert_graph_to_onnx.py --framework <pt, tf> --model bert-base-cased bert-base-cased.onnx
```

The conversion tool works for both PyTorch and TensorFlow models and ensures that:

- The model and its weights are correctly initialized from the Hugging Face model hub or a local checkpoint.
- The inputs and outputs are correctly mapped to their ONNX counterparts.
- The generated model can be correctly loaded through onnxruntime.

<Tip>

Currently, inputs and outputs are always exported with dynamic sequence axes, which prevents some optimizations in ONNX Runtime. If you would like to see support for fixed-length inputs/outputs, please open an issue on transformers.

</Tip>

The conversion tool also supports different options which let you tune the behavior of the generated model (an example command is shown after this list):

- **Change the target opset version of the generated model.** (A more recent opset generally supports more operators and enables faster inference.)

- **Export pipeline-specific prediction heads.** (Allows exporting the model along with its task-specific prediction head(s).)

- **Use the external data format (PyTorch only).** (Lets you export models whose size is above 2 GB ([more info](https://github.com/pytorch/pytorch/pull/33062)).)
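
For instance, assuming the script exposes `--opset`, `--pipeline`, and `--use-external-format` flags (run `python convert_graph_to_onnx.py --help` for the authoritative list in your version), these options map onto the command line roughly as follows:

```bash
# Hedged example: flag names may differ slightly across versions,
# check `python convert_graph_to_onnx.py --help` for the exact options.

# Target a specific ONNX opset version
python convert_graph_to_onnx.py --framework pt --opset 12 --model bert-base-cased bert-base-cased.onnx

# Export the model together with a task-specific prediction head
python convert_graph_to_onnx.py --framework pt --pipeline sentiment-analysis --model distilbert-base-uncased-finetuned-sst-2-english distilbert-sst2.onnx

# Use the external data format for models larger than 2 GB (PyTorch only)
python convert_graph_to_onnx.py --framework pt --use-external-format --model gpt2-large gpt2-large.onnx
```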

### Optimizations

ONNXRuntime includes some transformers-specific transformations to leverage optimized operations in the graph. Below are some of the operators which can be enabled to speed up inference through ONNXRuntime (*see note below*):

- Constant folding
- Attention layer fusing
- Skip connection LayerNormalization fusing
- FastGeLU approximation

Some of the optimizations performed by ONNX Runtime can be hardware-specific and thus lead to different performance if the model is used on a machine with a different hardware configuration than the one used for exporting it. For this reason, when using `convert_graph_to_onnx.py`, optimizations are not enabled, ensuring the model can be easily exported to various hardware. Optimizations can then be enabled when loading the model through ONNX Runtime for inference, as shown below.
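
For example, graph optimizations can be turned on at load time through ONNX Runtime's session options. This is a minimal sketch using the standard `onnxruntime` Python API; the file names are placeholders:

```python
import onnxruntime as ort

# Enable all graph optimizations (including transformer-specific fusions) at load time,
# so they are applied for the hardware the session actually runs on
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Optionally cache the optimized graph to avoid re-optimizing on every startup
sess_options.optimized_model_filepath = "bert-base-cased.optimized.onnx"

ort_session = ort.InferenceSession("bert-base-cased.onnx", sess_options)
```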
|
|
|
|
|
|
<Tip>
|
|
|
|
When quantization is enabled (see below), `convert_graph_to_onnx.py` script will enable optimizations on the
|
|
model because quantization would modify the underlying graph making it impossible for ONNX runtime to do the
|
|
optimizations afterwards.
|
|
|
|
</Tip>
|
|
|
|
<Tip>
|
|
|
|
For more information about the optimizations enabled by ONNXRuntime, please have a look at the [ONNXRuntime Github](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers).
|
|
|
|
</Tip>
|
|
|
|
### Quantization
|
|
|
|
ONNX exporter supports generating a quantized version of the model to allow efficient inference.
|
|
|
|
Quantization works by converting the memory representation of the parameters in the neural network to a compact integer
|
|
format. By default, weights of a neural network are stored as single-precision float (*float32*) which can express a
|
|
wide-range of floating-point numbers with decent precision. These properties are especially interesting at training
|
|
where you want fine-grained representation.
|
|
|
|
On the other hand, after the training phase, it has been shown one can greatly reduce the range and the precision of
|
|
*float32* numbers without changing the performances of the neural network.
|
|
|
|
More technically, *float32* parameters are converted to a type requiring fewer bits to represent each number, thus
|
|
reducing the overall size of the model. Here, we are enabling *float32* mapping to *int8* values (a non-floating,
|
|
single byte, number representation) according to the following formula:
|
|
|
|
$$y_{float32} = scale * x_{int8} - zero\_point$$
|
|
|
|
<Tip>
|
|
|
|
The quantization process will infer the parameter *scale* and *zero_point* from the neural network parameters
|
|
|
|
</Tip>
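
To make the mapping concrete, here is a small illustrative round-trip following the formula above (using unsigned 8-bit values for simplicity). This is not the quantization code used by ONNX Runtime, which has its own conventions and calibration logic:

```python
# Illustrative only: a tiny min/max quantization round-trip following the formula above.
# Real quantizers (e.g. in ONNX Runtime) use their own conventions and calibration.
import numpy as np

weights = np.array([-0.75, -0.1, 0.0, 0.3, 1.25], dtype=np.float32)

# Infer scale and zero_point from the observed range of the parameters
w_min, w_max = weights.min(), weights.max()
scale = (w_max - w_min) / 255.0          # map the full range onto 256 integer steps
zero_point = -w_min                      # offset so the smallest value maps to 0

# Quantize to single-byte integers, then dequantize with the formula above
x_int8 = np.round((weights + zero_point) / scale).astype(np.uint8)
y_float32 = scale * x_int8.astype(np.float32) - zero_point

print(x_int8)                            # e.g. [  0  83  96 134 255]
print(np.abs(weights - y_float32).max()) # reconstruction error bounded by ~scale / 2
```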

Leveraging tiny integers has numerous advantages when it comes to inference:

- Storing fewer bits (8 instead of 32 for each *float32* parameter) reduces the size of the model and makes it load faster.
- Integer operations execute an order of magnitude faster on modern hardware.
- Integer operations require less power.

In order to convert a transformers model to ONNX IR with quantized weights, you just need to specify `--quantize` when using `convert_graph_to_onnx.py`. You can also have a look at the `quantize()` utility method in the same script file.

Example of a quantized BERT model export:

```bash
python convert_graph_to_onnx.py --framework <pt, tf> --model bert-base-cased --quantize bert-base-cased.onnx
```

<Tip>

Quantization support requires ONNX Runtime >= 1.4.0.

</Tip>

<Tip>

When exporting a quantized model you will end up with two different ONNX files. The one specified at the end of the above command will contain the original ONNX model storing the *float32* weights. The second one, with a `-quantized` suffix, will hold the quantized parameters.

</Tip>

## TorchScript

<Tip>

This is the very beginning of our experiments with TorchScript and we are still exploring its capabilities with variable-input-size models. It is a focus of interest for us and we will deepen our analysis in upcoming releases, with more code examples, a more flexible implementation, and benchmarks comparing Python-based code with compiled TorchScript.

</Tip>

According to PyTorch's documentation: "TorchScript is a way to create serializable and optimizable models from PyTorch code". PyTorch's two modules [JIT and TRACE](https://pytorch.org/docs/stable/jit.html) allow developers to export their models to be reused in other programs, such as efficiency-oriented C++ programs.

We have provided an interface that allows the export of 🤗 Transformers models to TorchScript so that they can be reused in an environment other than a PyTorch-based Python program. Here we explain how to export and use our models using TorchScript.

Exporting a model requires two things:

- a forward pass with dummy inputs.
- model instantiation with the `torchscript` flag.

These necessities imply several things developers should be careful about, as detailed below.

### Implications

### TorchScript flag and tied weights

This flag is necessary because most of the language models in this repository have tied weights between their `Embedding` layer and their `Decoding` layer. TorchScript does not allow the export of models that have tied weights, so it is necessary to untie and clone the weights beforehand.

This implies that models instantiated with the `torchscript` flag have their `Embedding` layer and `Decoding` layer separated, which means that they should not be trained down the line. Training would de-synchronize the two layers, leading to unexpected results.

This is not the case for models that do not have a language model head, as those do not have tied weights. These models can be safely exported without the `torchscript` flag, as illustrated in the sketch below.
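
As a quick, illustrative check of the difference, you can compare the two instantiation modes on a model whose language-model head is normally tied to its input embeddings. The attribute names below are those of `GPT2LMHeadModel`; this is only a sketch:

```python
# Illustrative sketch: with torchscript=True the tied weights are cloned, so the
# LM head and the input embeddings no longer share the same storage.
from transformers import GPT2LMHeadModel

tied = GPT2LMHeadModel.from_pretrained("gpt2")
untied = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True)

print(tied.lm_head.weight.data_ptr() == tied.transformer.wte.weight.data_ptr())      # True: shared storage
print(untied.lm_head.weight.data_ptr() == untied.transformer.wte.weight.data_ptr())  # False: cloned weights
```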

### Dummy inputs and standard lengths

The dummy inputs are used for a forward pass of the model. While the inputs' values propagate through the layers, PyTorch keeps track of the different operations executed on each tensor. These recorded operations are then used to create the "trace" of the model.

The trace is created relative to the inputs' dimensions. It is therefore constrained by the dimensions of the dummy input, and will not work for any other sequence length or batch size. When trying with a different size, an error such as:

`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`

will be raised. It is therefore recommended to trace the model with a dummy input size at least as large as the largest input that will be fed to the model during inference. Padding can be used to fill the missing values. However, as the model will have been traced with a large input size, the dimensions of the different matrices will be large as well, resulting in more computation.

Be careful about the total number of operations done on each input and follow performance closely when exporting variable-sequence-length models; a minimal padding sketch is shown below.
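
For instance, the tokenizer can pad every input up to the length used for tracing. This is a minimal sketch assuming `traced_model` is a BERT model traced with dummy inputs of sequence length 512 (see the next section for how to create a trace); the positional inputs fed to the traced model must match those used when tracing:

```python
# Minimal sketch: pad every input up to the sequence length used when tracing,
# so the traced graph's fixed dimensions always match (assumes a trace length of 512).
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

encoded = tokenizer(
    "A short input that will be padded up to the traced length.",
    padding="max_length",   # always pad to max_length rather than to the longest in the batch
    max_length=512,         # must not exceed the dummy input size used for tracing
    truncation=True,
    return_tensors="pt",
)

# Feed the fixed-size tensors positionally, in the same order as the tracing inputs
outputs = traced_model(encoded["input_ids"], encoded["attention_mask"])
```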

### Using TorchScript in Python

Below is an example showing how to save and load models, as well as how to use the trace for inference.

#### Saving a model

This snippet shows how to use TorchScript to export a `BertModel`. Here the `BertModel` is instantiated according to a `BertConfig` class and then saved to disk under the filename `traced_bert.pt`.

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
```

#### Loading a model

This snippet shows how to load the `BertModel` that was previously saved to disk under the name `traced_bert.pt`. We are re-using the previously initialised `dummy_input`.

```python
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

all_encoder_layers, pooled_output = loaded_model(*dummy_input)
```

#### Using a traced model for inference

Using the traced model for inference is as simple as using its `__call__` dunder method:

```python
traced_model(tokens_tensor, segments_tensors)
```

### Deploying HuggingFace TorchScript models on AWS using the Neuron SDK

AWS introduced the [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/) instance family for low-cost, high-performance machine learning inference in the cloud. The Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator specializing in deep learning inference workloads. [AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#) is the SDK for Inferentia that supports tracing and optimizing transformers models for deployment on Inf1. The Neuron SDK provides:

1. An easy-to-use API with a one-line code change to trace and optimize a TorchScript model for inference in the cloud.
2. Out-of-the-box performance optimizations for [improved cost-performance](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/benchmark/).
3. Support for HuggingFace transformers models built with either [PyTorch](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/bert_tutorial/tutorial_pretrained_bert.html) or [TensorFlow](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/tensorflow/huggingface_bert/huggingface_bert.html).

#### Implications

Transformers models based on the [BERT (Bidirectional Encoder Representations from Transformers)](https://huggingface.co/docs/transformers/master/model_doc/bert) architecture, or its variants such as [distilBERT](https://huggingface.co/docs/transformers/master/model_doc/distilbert) and [roBERTa](https://huggingface.co/docs/transformers/master/model_doc/roberta), will run best on Inf1 for non-generative tasks such as extractive question answering, sequence classification, and token classification. Alternatively, text generation tasks can be adapted to run on Inf1, according to this [AWS Neuron MarianMT tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html). More information about models that can be converted out of the box on Inferentia can be found in the [Model Architecture Fit section of the Neuron documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia).

#### Dependencies

Using AWS Neuron to convert models requires the following dependencies and environment:

* A [Neuron SDK environment](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/index.html#installation-guide), which comes pre-configured on the [AWS Deep Learning AMI](https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-inferentia-launching.html).

#### Converting a Model for AWS Neuron

Using the same script as in [Using TorchScript in Python](https://huggingface.co/docs/transformers/master/en/serialization#using-torchscript-in-python) to trace a `BertModel`, import the `torch.neuron` framework extension to access the components of the Neuron SDK through a Python API:

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```

Then modify only the tracing line of code,

from:

```python
torch.jit.trace(model, [tokens_tensor, segments_tensors])
```

to:

```python
torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```

This change enables the Neuron SDK to trace the model and optimize it to run on Inf1 instances.
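
Re-using the names from the earlier TorchScript snippet, a minimal sketch of persisting the result follows, assuming the Neuron-traced object supports the standard TorchScript save/load API (as the AWS Neuron tutorials suggest); the file name is a placeholder:

```python
# Minimal sketch, assuming the Neuron-traced model supports the standard
# TorchScript save/load API (torch.neuron must be imported before loading).
neuron_model = torch.neuron.trace(model, [tokens_tensor, segments_tensors])
neuron_model.save("bert_neuron.pt")

# Later, e.g. on the Inf1 instance serving the model:
restored = torch.jit.load("bert_neuron.pt")
outputs = restored(tokens_tensor, segments_tensors)
```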

To learn more about AWS Neuron SDK features, tools, example tutorials, and the latest updates, please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).