<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Exporting 🤗 Transformers Models

If you need to deploy 🤗 Transformers models in production environments, we
recommend exporting them to a serialized format that can be loaded and executed
on specialized runtimes and hardware. In this guide, we'll show you how to
export 🤗 Transformers models in two widely used formats: ONNX and TorchScript.

Once exported, a model can be optimized for inference via techniques such as
quantization and pruning. If you are interested in optimizing your models to run
with maximum efficiency, check out the [🤗 Optimum
library](https://github.com/huggingface/optimum).

## ONNX

The [ONNX (Open Neural Network eXchange)](http://onnx.ai) project is an open
standard that defines a common set of operators and a common file format to
represent deep learning models in a wide variety of frameworks, including
PyTorch and TensorFlow. When a model is exported to the ONNX format, these
operators are used to construct a computational graph (often called an
_intermediate representation_) which represents the flow of data through the
neural network.

By exposing a graph with standardized operators and data types, ONNX makes it
easy to switch between frameworks. For example, a model trained in PyTorch can
be exported to ONNX format and then imported in TensorFlow (and vice versa).

🤗 Transformers provides a `transformers.onnx` package that enables you to
convert model checkpoints to an ONNX graph by leveraging configuration objects.
These configuration objects come ready-made for a number of model architectures,
and are designed to be easily extendable to other architectures.

Ready-made configurations include the following architectures:

<!--This table is automatically generated by make style, do not fill manually!-->

- ALBERT
- BART
- BERT
- CamemBERT
- DistilBERT
- GPT Neo
- I-BERT
- LayoutLM
- Longformer
- Marian
- mBART
- OpenAI GPT-2
- RoBERTa
- T5
- XLM-RoBERTa
- XLM-RoBERTa-XL

The ONNX conversion is supported for the PyTorch versions of the models. If you
would like to be able to convert a TensorFlow model, please let us know by
opening an issue.

In the next two sections, we'll show you how to:

* Export a supported model using the `transformers.onnx` package.
* Export a custom model for an unsupported architecture.

### Exporting a model to ONNX

To export a 🤗 Transformers model to ONNX, you'll first need to install some
extra dependencies:

```bash
pip install transformers[onnx]
```

The `transformers.onnx` package can then be used as a Python module:

```bash
python -m transformers.onnx --help

usage: Hugging Face Transformers ONNX exporter [-h] -m MODEL [--feature {causal-lm, ...}] [--opset OPSET] [--atol ATOL] output

positional arguments:
  output                Path indicating where to store generated ONNX model.

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL, --model MODEL
                        Model ID on huggingface.co or path on disk to load model from.
  --feature {causal-lm, ...}
                        The type of features to export the model with.
  --opset OPSET         ONNX opset version to export the model with.
  --atol ATOL           Absolute difference tolerance when validating the model.
```

Exporting a checkpoint using a ready-made configuration can be done as follows:

```bash
python -m transformers.onnx --model=distilbert-base-uncased onnx/
```

which should show the following logs:

```bash
Validating ONNX model...
        -[✓] ONNX model output names match reference model ({'last_hidden_state'})
        - Validating ONNX Model output "last_hidden_state":
                -[✓] (2, 8, 768) matches (2, 8, 768)
                -[✓] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
```

This exports an ONNX graph of the checkpoint defined by the `--model` argument.
In this example it is `distilbert-base-uncased`, but it can be any model on the
Hugging Face Hub or one that's stored locally.

The resulting `model.onnx` file can then be run on one of the [many
accelerators](https://onnx.ai/supported-tools.html#deployModel) that support the
ONNX standard. For example, we can load and run the model with [ONNX
Runtime](https://onnxruntime.ai/) as follows:

```python
>>> from transformers import AutoTokenizer
>>> from onnxruntime import InferenceSession

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> session = InferenceSession("onnx/model.onnx")
>>> # ONNX Runtime expects NumPy arrays as input
>>> inputs = tokenizer("Using DistilBERT with ONNX Runtime!", return_tensors="np")
>>> outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
```

The required output names (i.e. `["last_hidden_state"]`) can be obtained by
taking a look at the ONNX configuration of each model. For example, for
DistilBERT we have:

```python
>>> from transformers.models.distilbert import DistilBertConfig, DistilBertOnnxConfig

>>> config = DistilBertConfig()
>>> onnx_config = DistilBertOnnxConfig(config)
>>> print(list(onnx_config.outputs.keys()))
["last_hidden_state"]
```

### Selecting features for different model topologies

Each ready-made configuration comes with a set of _features_ that enable you to
export models for different types of topologies or tasks. As shown in the table
below, each feature is associated with a different auto class:

| Feature                              | Auto Class                           |
| ------------------------------------ | ------------------------------------ |
| `causal-lm`, `causal-lm-with-past`   | `AutoModelForCausalLM`               |
| `default`, `default-with-past`       | `AutoModel`                          |
| `masked-lm`                          | `AutoModelForMaskedLM`               |
| `question-answering`                 | `AutoModelForQuestionAnswering`      |
| `seq2seq-lm`, `seq2seq-lm-with-past` | `AutoModelForSeq2SeqLM`              |
| `sequence-classification`            | `AutoModelForSequenceClassification` |
| `token-classification`               | `AutoModelForTokenClassification`    |

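The auto class in the table determines which head the checkpoint is loaded with before the export. As a minimal illustration (the checkpoint name is only an example), the `sequence-classification` feature corresponds to loading the model with `AutoModelForSequenceClassification`:

```python
>>> from transformers import AutoModelForSequenceClassification

>>> # Example checkpoint; the `sequence-classification` feature maps to this auto class
>>> model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
```
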
For each configuration, you can find the list of supported features via the
`FeaturesManager`. For example, for DistilBERT we have:

```python
>>> from transformers.onnx.features import FeaturesManager

>>> distilbert_features = list(FeaturesManager.get_supported_features_for_model_type("distilbert").keys())
>>> print(distilbert_features)
["default", "masked-lm", "causal-lm", "sequence-classification", "token-classification", "question-answering"]
```

You can then pass one of these features to the `--feature` argument in the
`transformers.onnx` package. For example, to export a text-classification model
we can pick a fine-tuned model from the Hub and run:

```bash
python -m transformers.onnx --model=distilbert-base-uncased-finetuned-sst-2-english \
                            --feature=sequence-classification onnx/
```

which will display the following logs:

```bash
Validating ONNX model...
        -[✓] ONNX model output names match reference model ({'logits'})
        - Validating ONNX Model output "logits":
                -[✓] (2, 2) matches (2, 2)
                -[✓] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
```

Notice that in this case, the output names from the fine-tuned model are
`logits` instead of the `last_hidden_state` we saw with the
`distilbert-base-uncased` checkpoint earlier. This is expected since the
fine-tuned model has a sequence classification head.

<Tip>

The features that have a `with-past` suffix (e.g. `causal-lm-with-past`)
correspond to model topologies with precomputed hidden states (keys and values
in the attention blocks) that can be used for fast autoregressive decoding.

</Tip>

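If you want to check whether a given architecture exposes these `with-past` variants, you can query the `FeaturesManager` as before. The sketch below assumes GPT-2's ONNX configuration supports past key values (which it did at the time of writing):

```python
>>> from transformers.onnx.features import FeaturesManager

>>> gpt2_features = list(FeaturesManager.get_supported_features_for_model_type("gpt2").keys())
>>> print("causal-lm-with-past" in gpt2_features)
True
```
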
### Exporting a model for an unsupported architecture

If you wish to export a model whose architecture is not natively supported by
the library, there are three main steps to follow:

1. Implement a custom ONNX configuration.
2. Export the model to ONNX.
3. Validate the outputs of the PyTorch and exported models.

In this section, we'll look at how DistilBERT was implemented to show what's
involved with each step.

#### Implementing a custom ONNX configuration

Let's start with the ONNX configuration object. We provide three abstract
classes that you should inherit from, depending on the type of model
architecture you wish to export:

* Encoder-based models inherit from [`~onnx.config.OnnxConfig`]
* Decoder-based models inherit from [`~onnx.config.OnnxConfigWithPast`]
* Encoder-decoder models inherit from [`~onnx.config.OnnxSeq2SeqConfigWithPast`]

<Tip>

A good way to implement a custom ONNX configuration is to look at the existing
implementation in the `configuration_<model_name>.py` file of a similar architecture.

</Tip>

Since DistilBERT is an encoder-based model, its configuration inherits from
`OnnxConfig`:

```python
>>> from typing import Mapping, OrderedDict
>>> from transformers.onnx import OnnxConfig


>>> class DistilBertOnnxConfig(OnnxConfig):
...     @property
...     def inputs(self) -> Mapping[str, Mapping[int, str]]:
...         return OrderedDict(
...             [
...                 ("input_ids", {0: "batch", 1: "sequence"}),
...                 ("attention_mask", {0: "batch", 1: "sequence"}),
...             ]
...         )
```

Every configuration object must implement the `inputs` property and return a
mapping, where each key corresponds to an expected input, and each value
indicates the axis of that input. For DistilBERT, we can see that two inputs are
required: `input_ids` and `attention_mask`. These inputs have the same shape of
`(batch_size, sequence_length)`, which is why we see the same axes used in the
configuration.

<Tip>

Notice that the `inputs` property for `DistilBertOnnxConfig` returns an
`OrderedDict`. This ensures that the inputs are matched with their relative
position within the `PreTrainedModel.forward()` method when tracing the graph.
We recommend using an `OrderedDict` for the `inputs` and `outputs` properties
when implementing custom ONNX configurations.

</Tip>

Once you have implemented an ONNX configuration, you can instantiate it by
providing the base model's configuration as follows:

```python
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config = DistilBertOnnxConfig(config)
```

The resulting object has several useful properties. For example, you can view the
ONNX operator set that will be used during the export:

```python
>>> print(onnx_config.default_onnx_opset)
11
```

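In the same spirit, you can inspect the input mapping we defined in `DistilBertOnnxConfig` above; the printed `OrderedDict` simply mirrors what the `inputs` property returns:

```python
>>> print(onnx_config.inputs)
OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}), ('attention_mask', {0: 'batch', 1: 'sequence'})])
```
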
You can also view the outputs associated with the model as follows:

```python
>>> print(onnx_config.outputs)
OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
```

Notice that the outputs property follows the same structure as the inputs; it
returns an `OrderedDict` of named outputs and their axes. The output structure
is linked to the choice of feature that the configuration is initialized with.
By default, the ONNX configuration is initialized with the `default` feature
that corresponds to exporting a model loaded with the `AutoModel` class. If you
want to export a different model topology, just provide a different feature to
the `task` argument when you initialize the ONNX configuration. For example, if
we wished to export DistilBERT with a sequence classification head, we could
use:

```python
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config_for_seq_clf = DistilBertOnnxConfig(config, task="sequence-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch'})])
```

<Tip>

All of the base properties and methods associated with [`~onnx.config.OnnxConfig`] and the
other configuration classes can be overridden if needed. Check out
[`BartOnnxConfig`] for an advanced example.

</Tip>

#### Exporting the model

Once you have implemented the ONNX configuration, the next step is to export the
model. Here we can use the `export()` function provided by the
`transformers.onnx` package. This function expects the ONNX configuration, along
with the base model and tokenizer, and the path to save the exported file:

```python
>>> from pathlib import Path
>>> from transformers.onnx import export
>>> from transformers import AutoTokenizer, AutoModel

>>> onnx_path = Path("model.onnx")
>>> model_ckpt = "distilbert-base-uncased"
>>> base_model = AutoModel.from_pretrained(model_ckpt)
>>> tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

>>> onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```

The `onnx_inputs` and `onnx_outputs` returned by the `export()` function are
lists of the keys defined in the `inputs` and `outputs` properties of the
configuration. Once the model is exported, you can test that the model is
well-formed as follows:

```python
>>> import onnx

>>> onnx_model = onnx.load("model.onnx")
>>> onnx.checker.check_model(onnx_model)
```

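As a quick sanity check, you can also print the names returned by `export()`; under the setup above they should simply mirror the keys of the configuration's `inputs` and `outputs` properties:

```python
>>> print(onnx_inputs)
['input_ids', 'attention_mask']
>>> print(onnx_outputs)
['last_hidden_state']
```
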
<Tip>

If your model is larger than 2GB, you will see that many additional files are
created during the export. This is _expected_ because ONNX uses [Protocol
Buffers](https://developers.google.com/protocol-buffers/) to store the model and
these have a size limit of 2GB. See the [ONNX
documentation](https://github.com/onnx/onnx/blob/master/docs/ExternalData.md)
for instructions on how to load models with external data.

</Tip>

#### Validating the model outputs

The final step is to validate that the outputs from the base and exported model
agree within some absolute tolerance. Here we can use the
`validate_model_outputs()` function provided by the `transformers.onnx` package
as follows:

```python
>>> from transformers.onnx import validate_model_outputs

>>> validate_model_outputs(
...     onnx_config, tokenizer, base_model, onnx_path, onnx_outputs, onnx_config.atol_for_validation
... )
```

This function uses the `OnnxConfig.generate_dummy_inputs()` method to generate
inputs for the base and exported model, and the absolute tolerance can be
defined in the configuration. We generally find numerical agreement in the 1e-6
to 1e-4 range, although anything smaller than 1e-3 is likely to be OK.

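If you are curious about what the validation actually feeds into the two models, you can call `generate_dummy_inputs()` yourself. The sketch below assumes the method accepts a tokenizer and a `framework` argument, as in the version of the package this guide was written against:

```python
>>> from transformers import TensorType

>>> dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
>>> print(list(dummy_inputs.keys()))
['input_ids', 'attention_mask']
```
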
### Contributing a new configuration to 🤗 Transformers

We are looking to expand the set of ready-made configurations and welcome
contributions from the community! If you would like to contribute your addition
to the library, you will need to:

* Implement the ONNX configuration in the corresponding `configuration_<model_name>.py` file
* Include the model architecture and corresponding features in [`~onnx.features.FeaturesManager`]
* Add your model architecture to the tests in `test_onnx_v2.py`

Check out how the configuration for [IBERT was
contributed](https://github.com/huggingface/transformers/pull/14868/files) to
get an idea of what's involved.

## TorchScript

<Tip>

This is the very beginning of our experiments with TorchScript and we are still exploring its capabilities with
variable-input-size models. It is a focus of interest to us and we will deepen our analysis in upcoming releases,
with more code examples, a more flexible implementation, and benchmarks comparing Python-based code with compiled
TorchScript.

</Tip>

According to PyTorch's documentation: "TorchScript is a way to create serializable and optimizable models from PyTorch
code". PyTorch's two modules [JIT and TRACE](https://pytorch.org/docs/stable/jit.html) allow the developer to export
their model to be re-used in other programs, such as efficiency-oriented C++ programs.

We have provided an interface that allows the export of 🤗 Transformers models to TorchScript so that they can be reused
in a different environment than a PyTorch-based Python program. Here we explain how to export and use our models using
TorchScript.

Exporting a model requires two things:

- a forward pass with dummy inputs.
- model instantiation with the `torchscript` flag.

These necessities imply several things developers should be careful about. These are detailed below.

### Implications

### TorchScript flag and tied weights

This flag is necessary because most of the language models in this repository have tied weights between their
`Embedding` layer and their `Decoding` layer. TorchScript does not allow the export of models that have tied
weights, so it is necessary to untie and clone the weights beforehand.

This implies that models instantiated with the `torchscript` flag have their `Embedding` layer and `Decoding`
layer separated, which means that they should not be trained down the line. Training would de-synchronize the two
layers, leading to unexpected results.

This is not the case for models that do not have a language model head, as those do not have tied weights. These models
can be safely exported without the `torchscript` flag.

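As a minimal, hedged sketch of the point above (the checkpoint and input shape are only illustrative), a model that does carry a language modeling head would be instantiated with the flag before tracing:

```python
import torch
from transformers import BertForMaskedLM

# The `torchscript` flag unties and clones the embedding/decoding weights so the model can be traced
model = BertForMaskedLM.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# A dummy batch of token ids, only used to record the trace
dummy_input_ids = torch.randint(0, model.config.vocab_size, (1, 14))
traced_model = torch.jit.trace(model, dummy_input_ids)
```
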
### Dummy inputs and standard lengths

The dummy inputs are used to do a model forward pass. While the inputs' values are propagated through the layers,
PyTorch keeps track of the different operations executed on each tensor. These recorded operations are then used to
create the "trace" of the model.

The trace is created relative to the inputs' dimensions. It is therefore constrained by the dimensions of the dummy
input, and will not work for any other sequence length or batch size. When trying with a different size, an error such
as:

`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`

will be raised. It is therefore recommended to trace the model with a dummy input size at least as large as the largest
input that will be fed to the model during inference. Padding can be performed to fill the missing values. As the model
will have been traced with a large input size, however, the dimensions of the different matrices will be large as well,
resulting in more calculations.

It is recommended to be careful of the total number of operations done on each input and to follow performance closely
when exporting varying sequence-length models.

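The snippet below is a minimal sketch of that recommendation, assuming a maximum inference length of 512 tokens; the tokenizer pads the dummy input up to that length before the trace is recorded:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Pad the dummy input to the largest sequence length expected at inference time
dummy = tokenizer(
    "Who was Jim Henson?",
    padding="max_length",
    max_length=512,
    return_tensors="pt",
)

# The resulting trace accepts any input of shape (1, 512)
traced_model = torch.jit.trace(model, (dummy["input_ids"], dummy["attention_mask"]))
```
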
### Using TorchScript in Python

Below is an example showing how to save and load models, as well as how to use the trace for inference.

#### Saving a model

This snippet shows how to use TorchScript to export a `BertModel`. Here the `BertModel` is instantiated according
to a `BertConfig` class and then saved to disk under the filename `traced_bert.pt`.

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiating the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you are instantiating the model with *from_pretrained* you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
```

#### Loading a model

This snippet shows how to load the `BertModel` that was previously saved to disk under the name `traced_bert.pt`.
We are re-using the previously initialized `dummy_input`.

```python
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

all_encoder_layers, pooled_output = loaded_model(*dummy_input)
```

#### Using a traced model for inference

Using the traced model for inference is as simple as using its `__call__` dunder method:

```python
traced_model(tokens_tensor, segments_tensors)
```

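As an optional sanity check (a minimal sketch reusing the objects created in the saving example above), you can confirm that the traced module reproduces the eager model's outputs:

```python
import torch

with torch.no_grad():
    eager_outputs = model(tokens_tensor, segments_tensors)
    traced_outputs = traced_model(tokens_tensor, segments_tensors)

# With `torchscript=True` the model returns tuples; compare the hidden states
print(torch.allclose(eager_outputs[0], traced_outputs[0], atol=1e-5))
```
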
### Deploying HuggingFace TorchScript models on AWS using the Neuron SDK

AWS introduced the [Amazon EC2 Inf1](https://aws.amazon.com/ec2/instance-types/inf1/)
instance family for low-cost, high-performance machine learning inference in the cloud.
The Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator
specializing in deep learning inferencing workloads.
[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/#)
is the SDK for Inferentia that supports tracing and optimizing transformers models for
deployment on Inf1. The Neuron SDK provides:

1. Easy-to-use API with one line of code change to trace and optimize a TorchScript model for inference in the cloud.
2. Out of the box performance optimizations for [improved cost-performance](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/benchmark/)
3. Support for HuggingFace transformers models built with either [PyTorch](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/bert_tutorial/tutorial_pretrained_bert.html)
   or [TensorFlow](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/tensorflow/huggingface_bert/huggingface_bert.html).

#### Implications

Transformers models based on the [BERT (Bidirectional Encoder Representations from Transformers)](https://huggingface.co/docs/transformers/master/model_doc/bert)
architecture, or its variants such as [distilBERT](https://huggingface.co/docs/transformers/master/model_doc/distilbert)
and [roBERTa](https://huggingface.co/docs/transformers/master/model_doc/roberta),
will run best on Inf1 for non-generative tasks such as Extractive Question Answering,
Sequence Classification, and Token Classification. Alternatively, text generation
tasks can be adapted to run on Inf1, according to this [AWS Neuron MarianMT tutorial](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/transformers-marianmt.html).
More information about models that can be converted out of the box on Inferentia can be
found in the [Model Architecture Fit section of the Neuron documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/models/models-inferentia.html#models-inferentia).

#### Dependencies

Using AWS Neuron to convert models requires the following dependencies and environment:

* A [Neuron SDK environment](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/neuron-frameworks/pytorch-neuron/index.html#installation-guide),
  which comes pre-configured on [AWS Deep Learning AMI](https://docs.aws.amazon.com/dlami/latest/devguide/tutorial-inferentia-launching.html).

#### Converting a Model for AWS Neuron

Using the same script as in [Using TorchScript in Python](https://huggingface.co/docs/transformers/master/en/serialization#using-torchscript-in-python)
to trace a `BertModel`, you import the `torch.neuron` framework extension to access
the components of the Neuron SDK through a Python API.

```python
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
```

Then you only need to modify the tracing line of code

from:

```python
torch.jit.trace(model, [tokens_tensor, segments_tensors])
```

to:

```python
torch.neuron.trace(model, [tokens_tensor, segments_tensors])
```

This change enables the Neuron SDK to trace the model and optimize it to run on Inf1 instances.

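As a hedged sketch (following the AWS Neuron tutorials rather than an API from this repository, and with an illustrative filename), the Neuron-compiled module returned by `torch.neuron.trace` can then be saved and reloaded like any other TorchScript module:

```python
# Compile the model for Inferentia and persist the result
neuron_model = torch.neuron.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(neuron_model, "neuron_bert.pt")

# Later, typically on the Inf1 instance serving the model
loaded_neuron_model = torch.jit.load("neuron_bert.pt")
loaded_neuron_model(tokens_tensor, segments_tensors)
```
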
To learn more about AWS Neuron SDK features, tools, example tutorials, and latest updates,
please see the [AWS Neuron SDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html).