Breakup export guide (#19271)

* split onnx and torchscript docs

* make style

* apply reviews
Steven Liu 2022-10-03 13:18:29 -07:00 committed by GitHub
parent 18c06208c4
commit 68f50f3453
3 changed files with 347 additions and 327 deletions


@ -33,7 +33,9 @@
- local: converting_tensorflow_models
title: Converting from TensorFlow checkpoints
- local: serialization
title: Export to ONNX
- local: torchscript
title: Export to TorchScript
- local: troubleshooting
title: Troubleshoot
title: General usage


@ -10,36 +10,36 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Export to ONNX
If you need to deploy 🤗 Transformers models in production environments, we recommend
exporting them to a serialized format that can be loaded and executed on specialized
runtimes and hardware. In this guide, we'll show you how to export 🤗 Transformers
models to [ONNX (Open Neural Network eXchange)](http://onnx.ai).
<Tip>
Once exported, a model can be optimized for inference via techniques such as
quantization and pruning. If you are interested in optimizing your models to run with
maximum efficiency, check out the [🤗 Optimum
library](https://github.com/huggingface/optimum).
</Tip>
ONNX is an open standard that defines a common set of operators and a common file format
to represent deep learning models in a wide variety of frameworks, including PyTorch and
TensorFlow. When a model is exported to the ONNX format, these operators are used to
construct a computational graph (often called an _intermediate representation_) which
represents the flow of data through the neural network.
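As a quick illustration, once you have an exported `model.onnx` file (the export itself is covered below), you can list the standard operators that make up its graph with the `onnx` Python package; the file path here is simply the default output location used later in this guide:
```python
>>> import onnx
>>> from collections import Counter

>>> onnx_model = onnx.load("onnx/model.onnx")
>>> # Each node in the graph is an instance of a standard ONNX operator
>>> print(Counter(node.op_type for node in onnx_model.graph.node).most_common(5))
```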
By exposing a graph with standardized operators and data types, ONNX makes it easy to
switch between frameworks. For example, a model trained in PyTorch can be exported to
ONNX format and then imported in TensorFlow (and vice versa).
🤗 Transformers provides a [`transformers.onnx`](main_classes/onnx) package that enables
you to convert model checkpoints to an ONNX graph by leveraging configuration objects.
These configuration objects come ready-made for a number of model architectures, and are
designed to be easily extendable to other architectures.
Ready-made configurations include the following architectures:
@ -106,10 +106,10 @@ In the next two sections, we'll show you how to:
* Export a supported model using the `transformers.onnx` package.
* Export a custom model for an unsupported architecture.
## Exporting a model to ONNX
To export a 🤗 Transformers model to ONNX, you'll first need to install some extra
dependencies:
```bash
pip install transformers[onnx]
@ -141,7 +141,7 @@ Exporting a checkpoint using a ready-made configuration can be done as follows:
python -m transformers.onnx --model=distilbert-base-uncased onnx/
```
You should see the following logs:
```bash
Validating ONNX model...
@ -152,13 +152,13 @@ Validating ONNX model...
All good, model saved at: onnx/model.onnx
```
This exports an ONNX graph of the checkpoint defined by the `--model` argument. In this
example, it is `distilbert-base-uncased`, but it can be any checkpoint on the Hugging
Face Hub or one that's stored locally.
The resulting `model.onnx` file can then be run on one of the [many
accelerators](https://onnx.ai/supported-tools.html#deployModel) that support the ONNX
standard. For example, we can load and run the model with [ONNX
Runtime](https://onnxruntime.ai/) as follows:
```python
@ -172,9 +172,8 @@ Runtime](https://onnxruntime.ai/) as follows:
>>> outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
```
The required output names (like `["last_hidden_state"]`) can be obtained by taking a
look at the ONNX configuration of each model. For example, for DistilBERT we have:
```python
>>> from transformers.models.distilbert import DistilBertConfig, DistilBertOnnxConfig
@ -185,20 +184,19 @@ DistilBERT we have:
["last_hidden_state"]
```
The process is identical for TensorFlow checkpoints on the Hub. For example, we can
export a pure TensorFlow checkpoint from the [Keras
organization](https://huggingface.co/keras-io) as follows:
```bash
python -m transformers.onnx --model=keras-io/transformers-qa onnx/
```
To export a model that's stored locally, you'll need to have the model's weights and
tokenizer files stored in a directory. For example, we can load and save a checkpoint as
follows:
<frameworkcontent> <pt>
```python
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
@ -216,8 +214,7 @@ argument of the `transformers.onnx` package to the desired directory:
```bash
python -m transformers.onnx --model=local-pt-checkpoint onnx/
```
</pt> <tf>
```python
>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
@ -235,14 +232,13 @@ argument of the `transformers.onnx` package to the desired directory:
```bash
python -m transformers.onnx --model=local-tf-checkpoint onnx/
```
</tf> </frameworkcontent>
## Selecting features for different model tasks
Each ready-made configuration comes with a set of _features_ that enable you to export
models for different types of tasks. As shown in the table below, each feature is
associated with a different `AutoClass`:
| Feature | Auto Class |
| ------------------------------------ | ------------------------------------ |
@ -255,7 +251,7 @@ below, each feature is associated with a different auto class:
| `token-classification` | `AutoModelForTokenClassification` |
For each configuration, you can find the list of supported features via the
[`~transformers.onnx.FeaturesManager`]. For example, for DistilBERT we have:
```python
>>> from transformers.onnx.features import FeaturesManager
@ -266,15 +262,15 @@ For each configuration, you can find the list of supported features via the
```
You can then pass one of these features to the `--feature` argument in the
`transformers.onnx` package. For example, to export a text-classification model we can
pick a fine-tuned model from the Hub and run:
```bash
python -m transformers.onnx --model=distilbert-base-uncased-finetuned-sst-2-english \
--feature=sequence-classification onnx/
```
This displays the following logs:
```bash
Validating ONNX model...
@ -285,37 +281,35 @@ Validating ONNX model...
All good, model saved at: onnx/model.onnx
```
Notice that in this case, the output names from the fine-tuned model are `logits`
instead of the `last_hidden_state` we saw with the `distilbert-base-uncased` checkpoint
earlier. This is expected since the fine-tuned model has a sequence classification head.
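For example, here is a minimal sketch of running this export with ONNX Runtime, analogous to the earlier example but with the `logits` output name (it assumes the export above was saved to `onnx/model.onnx`):
```python
>>> from onnxruntime import InferenceSession
>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> session = InferenceSession("onnx/model.onnx")
>>> inputs = tokenizer("Using DistilBERT with ONNX Runtime!", return_tensors="np")
>>> logits = session.run(output_names=["logits"], input_feed=dict(inputs))[0]
>>> logits.shape
(1, 2)
```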
<Tip>
The features that have a `with-past` suffix (like `causal-lm-with-past`) correspond to
model classes with precomputed hidden states (keys and values in the attention blocks)
that can be used for fast autoregressive decoding.
</Tip>
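For example, exporting GPT-2 (one of the ready-made configurations) with precomputed key/value states could look like this (the checkpoint and output directory are only illustrative):
```bash
python -m transformers.onnx --model=gpt2 --feature=causal-lm-with-past onnx/
```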
## Exporting a model for an unsupported architecture
If you wish to export a model whose architecture is not natively supported by the
library, there are three main steps to follow:
1. Implement a custom ONNX configuration.
2. Export the model to ONNX.
3. Validate the outputs of the PyTorch and exported models.
In this section, we'll look at how DistilBERT was implemented to show what's involved
with each step.
### Implementing a custom ONNX configuration
Let's start with the ONNX configuration object. We provide three abstract classes that
you should inherit from, depending on the type of model architecture you wish to export:
* Encoder-based models inherit from [`~onnx.config.OnnxConfig`]
* Decoder-based models inherit from [`~onnx.config.OnnxConfigWithPast`]
@ -347,25 +341,24 @@ Since DistilBERT is an encoder-based model, its configuration inherits from
... )
```
Every configuration object must implement the `inputs` property and return a mapping,
where each key corresponds to an expected input, and each value indicates the axis of
that input. For DistilBERT, we can see that two inputs are required: `input_ids` and
`attention_mask`. These inputs have the same shape, `(batch_size, sequence_length)`,
which is why we see the same axes used in the configuration.
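For example, here is roughly what that mapping looks like for DistilBERT (instantiating the configuration from the base model's configuration is covered in more detail below):
```python
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> print(DistilBertOnnxConfig(config).inputs)
OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}), ('attention_mask', {0: 'batch', 1: 'sequence'})])
```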
<Tip>
Notice that the `inputs` property for `DistilBertOnnxConfig` returns an `OrderedDict`. This
ensures that the inputs are matched with their relative position within the
`PreTrainedModel.forward()` method when tracing the graph. We recommend using an
`OrderedDict` for the `inputs` and `outputs` properties when implementing custom ONNX
configurations.
</Tip>
Once you have implemented an ONNX configuration, you can instantiate it by providing the
base model's configuration as follows:
```python
>>> from transformers import AutoConfig
@ -374,8 +367,8 @@ providing the base model's configuration as follows:
>>> onnx_config = DistilBertOnnxConfig(config)
```
The resulting object has several useful properties. For example, you can view the ONNX
operator set that will be used during the export:
```python
>>> print(onnx_config.default_onnx_opset)
@ -389,15 +382,14 @@ You can also view the outputs associated with the model as follows:
OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
```
Notice that the outputs property follows the same structure as the inputs; it returns an
`OrderedDict` of named outputs and their shapes. The output structure is linked to the
choice of feature that the configuration is initialized with. By default, the ONNX
configuration is initialized with the `default` feature that corresponds to exporting a
model loaded with the `AutoModel` class. If you want to export a model for another task,
just provide a different feature to the `task` argument when you initialize the ONNX
configuration. For example, if we wished to export DistilBERT with a sequence
classification head, we could use:
```python
>>> from transformers import AutoConfig
@ -410,18 +402,18 @@ OrderedDict([('logits', {0: 'batch'})])
<Tip>
All of the base properties and methods associated with [`~onnx.config.OnnxConfig`] and
the other configuration classes can be overridden if needed. Check out [`BartOnnxConfig`]
for an advanced example.
</Tip>
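As a small, purely illustrative sketch of such an override (the subclass name and opset value here are made up for the example):
```python
>>> from transformers.models.distilbert import DistilBertOnnxConfig

>>> class DistilBertOnnxConfigWithNewerOpset(DistilBertOnnxConfig):
...     @property
...     def default_onnx_opset(self) -> int:
...         # Request a more recent opset than the default for this architecture
...         return 13
```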
### Exporting the model
Once you have implemented the ONNX configuration, the next step is to export the model.
Here we can use the `export()` function provided by the `transformers.onnx` package.
This function expects the ONNX configuration, along with the base model and tokenizer,
and the path to save the exported file:
```python
>>> from pathlib import Path
@ -436,10 +428,9 @@ with the base model and tokenizer, and the path to save the exported file:
>>> onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```
The `onnx_inputs` and `onnx_outputs` returned by the `export()` function are lists of
the keys defined in the `inputs` and `outputs` properties of the configuration. Once the
model is exported, you can test that the model is well formed as follows:
```python
>>> import onnx
@ -450,21 +441,20 @@ formed as follows:
<Tip>
If your model is larger than 2GB, you will see that many additional files are created
during the export. This is _expected_ because ONNX uses [Protocol
Buffers](https://developers.google.com/protocol-buffers/) to store the model and these
have a size limit of 2GB. See the [ONNX
documentation](https://github.com/onnx/onnx/blob/master/docs/ExternalData.md) for
instructions on how to load models with external data.
</Tip>
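For reference, here is a small sketch of loading such an export with the `onnx` package; when the external data files sit next to `model.onnx`, they are picked up automatically, and `load_external_data=False` lets you inspect the graph without pulling the full weights into memory:
```python
>>> import onnx

>>> # External data files in the same directory are loaded automatically
>>> onnx_model = onnx.load("onnx/model.onnx")

>>> # Load only the graph structure, leaving the large weight tensors on disk
>>> onnx_model_light = onnx.load("onnx/model.onnx", load_external_data=False)
```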
### Validating the model outputs
The final step is to validate that the outputs from the base and exported model agree
within some absolute tolerance. Here we can use the `validate_model_outputs()` function
provided by the `transformers.onnx` package as follows:
```python
>>> from transformers.onnx import validate_model_outputs
@ -474,220 +464,23 @@ as follows:
... )
```
This function uses the [`~transformers.onnx.OnnxConfig.generate_dummy_inputs`] method to
generate inputs for the base and exported model, and the absolute tolerance can be
defined in the configuration. We generally find numerical agreement in the 1e-6 to 1e-4
range, although anything smaller than 1e-3 is likely to be OK.
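If you want to inspect what the validation sees, a short sketch (reusing the `onnx_config` and `tokenizer` from the previous steps) is to generate the dummy inputs yourself:
```python
>>> from transformers.utils import TensorType

>>> # The dummy inputs mirror the `inputs` property of the configuration
>>> dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
>>> print(list(dummy_inputs.keys()))
['input_ids', 'attention_mask']
```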
## Contributing a new configuration to 🤗 Transformers
We are looking to expand the set of ready-made configurations and welcome contributions
from the community! If you would like to contribute your addition to the library, you
will need to:
* Implement the ONNX configuration in the corresponding `configuration_<model_name>.py`
file
* Include the model architecture and corresponding features in
[`~onnx.features.FeaturesManager`]
* Add your model architecture to the tests in `test_onnx_v2.py`
Check out how the configuration for [IBERT was
contributed](https://github.com/huggingface/transformers/pull/14868/files) to get an
idea of what's involved.


@ -0,0 +1,225 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Export to TorchScript
<Tip>
This is the very beginning of our experiments with TorchScript and we are still
exploring its capabilities with variable-input-size models. It is a focus of interest to
us and we will deepen our analysis in upcoming releases, with more code examples, a more
flexible implementation, and benchmarks comparing Python-based code with compiled
TorchScript.
</Tip>
According to the [TorchScript documentation](https://pytorch.org/docs/stable/jit.html):
> TorchScript is a way to create serializable and optimizable models from PyTorch code.
There are two PyTorch modules, [JIT and
TRACE](https://pytorch.org/docs/stable/jit.html), that allow developers to export their
models to be reused in other programs like efficiency-oriented C++ programs.
We provide an interface that allows you to export 🤗 Transformers models to TorchScript
so they can be reused in a different environment than PyTorch-based Python programs.
Here, we explain how to export and use our models using TorchScript.
Exporting a model requires two things:
- model instantiation with the `torchscript` flag
- a forward pass with dummy inputs
These necessities imply several things developers should be careful about as detailed
below.
## TorchScript flag and tied weights
The `torchscript` flag is necessary because most of the 🤗 Transformers language models
have tied weights between their `Embedding` layer and their `Decoding` layer.
TorchScript does not allow you to export models that have tied weights, so it is
necessary to untie and clone the weights beforehand.
Models instantiated with the `torchscript` flag have their `Embedding` layer and
`Decoding` layer separated, which means that they should not be trained down the line.
Training would desynchronize the two layers, leading to unexpected results.
This is not the case for models that do not have a language model head, as those do not
have tied weights. These models can be safely exported without the `torchscript` flag.
## Dummy inputs and standard lengths
The dummy inputs are used for a model's forward pass. While the inputs' values are
propagated through the layers, PyTorch keeps track of the different operations executed
on each tensor. These recorded operations are then used to create the *trace* of the
model.
The trace is created relative to the inputs' dimensions. It is therefore constrained by
the dimensions of the dummy input, and will not work for any other sequence length or
batch size. When trying with a different size, the following error is raised:
```
The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2
```
We recommend you trace the model with a dummy input size at least as large as the
largest input that will be fed to the model during inference. Padding can help fill the
missing values. However, since the model is traced with a larger input size, the
dimensions of the matrix will also be large, resulting in more calculations.
Be careful of the total number of operations done on each input and follow performance
closely when exporting models with varying sequence lengths.
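For instance, here is a minimal sketch of padding every example to a fixed tracing length with the tokenizer (the `max_length` value is only illustrative):
```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Pad (and truncate) every example to the fixed length chosen for tracing
inputs = tokenizer(
    "Who was Jim Henson?",
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
print(inputs["input_ids"].shape)  # torch.Size([1, 128])
```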
## Using TorchScript in Python
This section demonstrates how to save and load models as well as how to use the trace
for inference.
### Saving a model
To export a `BertModel` with TorchScript, instantiate `BertModel` from the `BertConfig`
class and then save it to disk under the filename `traced_bert.pt`:
```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
enc = BertTokenizer.from_pretrained("bert-base-uncased")
# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
vocab_size_or_config_json_file=32000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
torchscript=True,
)
# Instantiating the model
model = BertModel(config)
# The model needs to be in evaluation mode
model.eval()
# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
```
### Loading a model
Now you can load the previously saved `BertModel`, `traced_bert.pt`, from disk and use
it on the previously initialized `dummy_input`:
```python
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()
all_encoder_layers, pooled_output = loaded_model(*dummy_input)
```
### Using a traced model for inference
Use the traced model for inference by using its `__call__` dunder method:
```python
traced_model(tokens_tensor, segments_tensors)
```
## Deploy Hugging Face TorchScript models to AWS with the Neuron SDK
AWS introduced the [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/)
instance family for low cost, high performance machine learning inference in the cloud.
The Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware
accelerator, specializing in deep learning inferencing workloads. [AWS
Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#) is the SDK for
Inferentia that supports tracing and optimizing transformers models for deployment on
Inf1. The Neuron SDK provides:
1. Easy-to-use API with one line of code change to trace and optimize a TorchScript
model for inference in the cloud.
2. Out of the box performance optimizations for [improved
cost-performance](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/benchmark/).
3. Support for Hugging Face transformers models built with either
[PyTorch](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/bert_tutorial/tutorial_pretrained_bert.html)
or
[TensorFlow](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/tensorflow/huggingface_bert/huggingface_bert.html).
### Implications
Transformers models based on the [BERT (Bidirectional Encoder Representations from
Transformers)](https://huggingface.co/docs/transformers/main/model_doc/bert)
architecture, or its variants such as
[distilBERT](https://huggingface.co/docs/transformers/main/model_doc/distilbert) and
[roBERTa](https://huggingface.co/docs/transformers/main/model_doc/roberta) run best on
Inf1 for non-generative tasks such as extractive question answering, sequence
classification, and token classification. However, text generation tasks can still be
adapted to run on Inf1 according to this [AWS Neuron MarianMT
tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html).
More information about models that can be converted out of the box on Inferentia can be
found in the [Model Architecture
Fit](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia)
section of the Neuron documentation.
### Dependencies
Using AWS Neuron to convert models requires a [Neuron SDK
environment](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/index.html#installation-guide)
which comes preconfigured on [AWS Deep Learning
AMI](https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-inferentia-launching.html).
### Converting a model for AWS Neuron
Convert a model for AWS Neuron using the same code from [Using TorchScript in
Python](serialization#using-torchscript-in-python) to trace a `BertModel`. Import the
`torch.neuron` framework extension to access the components of the Neuron SDK through a
Python API:
```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```
You only need to modify the following line:
```diff
- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```
This enables the Neuron SDK to trace the model and optimize it for Inf1 instances.
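As a rough sketch (reusing the `model`, `tokens_tensor`, `segments_tensors`, and `dummy_input` variables from the earlier example), the Neuron-traced model can then be saved and reloaded like any other TorchScript module; refer to the AWS Neuron documentation for the officially supported workflow:
```python
# Trace with the Neuron compiler instead of torch.jit.trace
neuron_model = torch.neuron.trace(model, [tokens_tensor, segments_tensors])

# Save and reload it like a regular TorchScript module
torch.jit.save(neuron_model, "neuron_bert.pt")
loaded_neuron_model = torch.jit.load("neuron_bert.pt")

outputs = loaded_neuron_model(*dummy_input)
```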
To learn more about AWS Neuron SDK features, tools, example tutorials, and latest
updates, please see the [AWS Neuron SDK
documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).