<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Export to ONNX

If you need to deploy 🤗 Transformers models in production environments, we recommend
exporting them to a serialized format that can be loaded and executed on specialized
runtimes and hardware. In this guide, we'll show you how to export 🤗 Transformers
models to [ONNX (Open Neural Network eXchange)](http://onnx.ai).

ONNX is an open standard that defines a common set of operators and a common file format
to represent deep learning models in a wide variety of frameworks, including PyTorch and
TensorFlow. When a model is exported to the ONNX format, these operators are used to
construct a computational graph (often called an _intermediate representation_) which
represents the flow of data through the neural network.

By exposing a graph with standardized operators and data types, ONNX makes it easy to
switch between frameworks. For example, a model trained in PyTorch can be exported to
ONNX format and then imported in TensorFlow (and vice versa).

🤗 Transformers provides a [`transformers.onnx`](main_classes/onnx) package that enables
you to convert model checkpoints to an ONNX graph by leveraging configuration objects.
These configuration objects come ready-made for a number of model architectures, and are
designed to be easily extendable to other architectures.

<Tip>

You can also export 🤗 Transformers models with the [`optimum.exporters.onnx` package](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model)
from 🤗 Optimum.

Once exported, a model can be:

- Optimized for inference via techniques such as quantization and graph optimization.
- Run with ONNX Runtime via [`ORTModelForXXX` classes](https://huggingface.co/docs/optimum/onnxruntime/package_reference/modeling_ort),
  which follow the same `AutoModel` API as the one you are used to in 🤗 Transformers.
- Run with [optimized inference pipelines](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/pipelines),
  which have the same API as the [`pipeline`] function in 🤗 Transformers.

To explore all these features, check out the [🤗 Optimum library](https://github.com/huggingface/optimum).

</Tip>
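
As a quick illustration of the Optimum route mentioned in the tip above, the sketch below loads a checkpoint with an
`ORTModelForXXX` class. It assumes `optimum[onnxruntime]` is installed, and the exact export flag may differ between
Optimum versions (`export=True` in recent releases, `from_transformers=True` in older ones):

```python
# Minimal sketch of the Optimum + ONNX Runtime workflow described in the tip above.
# Assumes `pip install optimum[onnxruntime]`; in older Optimum releases the export
# flag is `from_transformers=True` instead of `export=True`.
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_id)
ort_model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)

# The ORT model follows the familiar `AutoModel` call signature
inputs = tokenizer("Exporting to ONNX is straightforward!", return_tensors="pt")
logits = ort_model(**inputs).logits
```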

Ready-made configurations include the following architectures:

<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->

- ALBERT
- BART
- BEiT
- BERT
- BigBird
- BigBird-Pegasus
- Blenderbot
- BlenderbotSmall
- BLOOM
- CamemBERT
- Chinese-CLIP
- CLIP
- CodeGen
- Conditional DETR
- ConvBERT
- ConvNeXT
- Data2VecText
- Data2VecVision
- DeBERTa
- DeBERTa-v2
- DeiT
- DETR
- DistilBERT
- EfficientNet
- ELECTRA
- ERNIE
- FlauBERT
- GPT Neo
- GPT-J
- GPT-Sw3
- GroupViT
- I-BERT
- ImageGPT
- LayoutLM
- LayoutLMv3
- LeViT
- Longformer
- LongT5
- M2M100
- Marian
- mBART
- MEGA
- MobileBERT
- MobileNetV1
- MobileNetV2
- MobileViT
- MT5
- OpenAI GPT-2
- OWL-ViT
- Perceiver
- PLBart
- PoolFormer
- RemBERT
- ResNet
- RoBERTa
- RoBERTa-PreLayerNorm
- RoFormer
- SegFormer
- SqueezeBERT
- Swin Transformer
- T5
- Table Transformer
- Vision Encoder decoder
- ViT
- Whisper
- X-MOD
- XLM
- XLM-RoBERTa
- XLM-RoBERTa-XL
- YOLOS

In the next two sections, we'll show you how to:

* Export a supported model using the `transformers.onnx` package.
* Export a custom model for an unsupported architecture.

## Exporting a model to ONNX

<Tip>

The recommended way of exporting a model is now to use
[`optimum.exporters.onnx`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli);
don't worry, it is very similar to `transformers.onnx`!

</Tip>

To export a 🤗 Transformers model to ONNX, you'll first need to install some extra
dependencies:

```bash
pip install transformers[onnx]
```

The `transformers.onnx` package can then be used as a Python module:

```bash
python -m transformers.onnx --help

usage: Hugging Face Transformers ONNX exporter [-h] -m MODEL [--feature {causal-lm, ...}] [--opset OPSET] [--atol ATOL] output

positional arguments:
  output                Path indicating where to store generated ONNX model.

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL, --model MODEL
                        Model ID on huggingface.co or path on disk to load model from.
  --feature {causal-lm, ...}
                        The type of features to export the model with.
  --opset OPSET         ONNX opset version to export the model with.
  --atol ATOL           Absolute difference tolerance when validating the model.
```

Exporting a checkpoint using a ready-made configuration can be done as follows:

```bash
python -m transformers.onnx --model=distilbert-base-uncased onnx/
```

You should see the following logs:

```bash
Validating ONNX model...
  -[✓] ONNX model output names match reference model ({'last_hidden_state'})
  - Validating ONNX Model output "last_hidden_state":
    -[✓] (2, 8, 768) matches (2, 8, 768)
    -[✓] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
```

This exports an ONNX graph of the checkpoint defined by the `--model` argument. In this
example, it is `distilbert-base-uncased`, but it can be any checkpoint on the Hugging
Face Hub or one that's stored locally.

The resulting `model.onnx` file can then be run on one of the [many
accelerators](https://onnx.ai/supported-tools.html#deployModel) that support the ONNX
standard. For example, we can load and run the model with [ONNX
Runtime](https://onnxruntime.ai/) as follows:

```python
>>> from transformers import AutoTokenizer
>>> from onnxruntime import InferenceSession

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> session = InferenceSession("onnx/model.onnx")
>>> # ONNX Runtime expects NumPy arrays as input
>>> inputs = tokenizer("Using DistilBERT with ONNX Runtime!", return_tensors="np")
>>> outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
```

The required output names (like `["last_hidden_state"]`) can be obtained by taking a
look at the ONNX configuration of each model. For example, for DistilBERT we have:

```python
>>> from transformers.models.distilbert import DistilBertConfig, DistilBertOnnxConfig

>>> config = DistilBertConfig()
>>> onnx_config = DistilBertOnnxConfig(config)
>>> print(list(onnx_config.outputs.keys()))
["last_hidden_state"]
```

The process is identical for TensorFlow checkpoints on the Hub. For example, we can
export a pure TensorFlow checkpoint from the [Keras
organization](https://huggingface.co/keras-io) as follows:

```bash
python -m transformers.onnx --model=keras-io/transformers-qa onnx/
```

To export a model that's stored locally, you'll need to have the model's weights and
tokenizer files stored in a directory. For example, we can load and save a checkpoint as
follows:

<frameworkcontent>
<pt>

```python
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification

>>> # Load tokenizer and PyTorch weights from the Hub
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> pt_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

>>> # Save to disk
>>> tokenizer.save_pretrained("local-pt-checkpoint")
>>> pt_model.save_pretrained("local-pt-checkpoint")
```

Once the checkpoint is saved, we can export it to ONNX by pointing the `--model`
argument of the `transformers.onnx` package to the desired directory:

```bash
python -m transformers.onnx --model=local-pt-checkpoint onnx/
```

</pt>
<tf>

```python
>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

>>> # Load tokenizer and TensorFlow weights from the Hub
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

>>> # Save to disk
>>> tokenizer.save_pretrained("local-tf-checkpoint")
>>> tf_model.save_pretrained("local-tf-checkpoint")
```

Once the checkpoint is saved, we can export it to ONNX by pointing the `--model`
argument of the `transformers.onnx` package to the desired directory:

```bash
python -m transformers.onnx --model=local-tf-checkpoint onnx/
```

</tf>
</frameworkcontent>

## Selecting features for different model tasks

<Tip>

The recommended way of exporting a model is now to use `optimum.exporters.onnx`.
You can check the [🤗 Optimum documentation](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#selecting-a-task)
to learn how to select a task.

</Tip>

Each ready-made configuration comes with a set of _features_ that enable you to export
models for different types of tasks. As shown in the table below, each feature is
associated with a different `AutoClass`:

| Feature                              | Auto Class                           |
| ------------------------------------ | ------------------------------------ |
| `causal-lm`, `causal-lm-with-past`   | `AutoModelForCausalLM`               |
| `default`, `default-with-past`       | `AutoModel`                          |
| `masked-lm`                          | `AutoModelForMaskedLM`               |
| `question-answering`                 | `AutoModelForQuestionAnswering`      |
| `seq2seq-lm`, `seq2seq-lm-with-past` | `AutoModelForSeq2SeqLM`              |
| `sequence-classification`            | `AutoModelForSequenceClassification` |
| `token-classification`               | `AutoModelForTokenClassification`    |
For each configuration, you can find the list of supported features via the
|
|
[`~transformers.onnx.FeaturesManager`]. For example, for DistilBERT we have:
|
|
|
|
```python
|
|
>>> from transformers.onnx.features import FeaturesManager
|
|
|
|
>>> distilbert_features = list(FeaturesManager.get_supported_features_for_model_type("distilbert").keys())
|
|
>>> print(distilbert_features)
|
|
["default", "masked-lm", "causal-lm", "sequence-classification", "token-classification", "question-answering"]
|
|
```
|
|
|
|
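
You can also go the other way and look up which auto class from the table above backs a given feature. A small
sketch, assuming the `get_model_class_for_feature` helper of [`~transformers.onnx.FeaturesManager`] (PyTorch classes
are returned by default):

```python
>>> from transformers.onnx.features import FeaturesManager

>>> # Resolve the auto class used to load a model for a given feature
>>> model_class = FeaturesManager.get_model_class_for_feature("sequence-classification")
>>> print(model_class.__name__)
AutoModelForSequenceClassification
```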

You can then pass one of these features to the `--feature` argument in the
`transformers.onnx` package. For example, to export a text-classification model we can
pick a fine-tuned model from the Hub and run:

```bash
python -m transformers.onnx --model=distilbert-base-uncased-finetuned-sst-2-english \
                            --feature=sequence-classification onnx/
```

This displays the following logs:

```bash
Validating ONNX model...
  -[✓] ONNX model output names match reference model ({'logits'})
  - Validating ONNX Model output "logits":
    -[✓] (2, 2) matches (2, 2)
    -[✓] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
```

Notice that in this case, the output name from the fine-tuned model is `logits`
instead of the `last_hidden_state` we saw with the `distilbert-base-uncased` checkpoint
earlier. This is expected since the fine-tuned model has a sequence classification head.
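
To make the difference concrete, here is a small sketch that runs the exported classifier with ONNX Runtime,
following the same pattern as the earlier example but requesting the `logits` output (the example sentence is
arbitrary):

```python
>>> from transformers import AutoTokenizer
>>> from onnxruntime import InferenceSession

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> session = InferenceSession("onnx/model.onnx")
>>> inputs = tokenizer("ONNX Runtime is great for deployment!", return_tensors="np")
>>> # "logits" is the output name validated in the logs above
>>> logits = session.run(output_names=["logits"], input_feed=dict(inputs))[0]
```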

<Tip>

The features that have a `with-past` suffix (like `causal-lm-with-past`) correspond to
model classes with precomputed hidden states (key and values in the attention blocks)
that can be used for fast autoregressive decoding.

</Tip>
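
For instance, you can check which `with-past` variants a given architecture supports by querying
[`~transformers.onnx.FeaturesManager`] as we did for DistilBERT above. A small sketch, using GPT-2 as an assumed
example of a model with cached decoding support:

```python
>>> from transformers.onnx.features import FeaturesManager

>>> # List only the feature names that support cached key/values
>>> gpt2_features = FeaturesManager.get_supported_features_for_model_type("gpt2").keys()
>>> print([feature for feature in gpt2_features if feature.endswith("with-past")])
```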

<Tip>

For `VisionEncoderDecoder` type models, the encoder and decoder parts are
exported separately as two ONNX files named `encoder_model.onnx` and `decoder_model.onnx` respectively.

</Tip>

## Exporting a model for an unsupported architecture

<Tip>

If you wish to contribute by adding support for a model that cannot currently be exported, you should first check if it is
supported in [`optimum.exporters.onnx`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/package_reference/configuration#supported-architectures),
and if it is not, [contribute to 🤗 Optimum](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/contribute)
directly.

</Tip>

If you wish to export a model whose architecture is not natively supported by the
library, there are three main steps to follow:

1. Implement a custom ONNX configuration.
2. Export the model to ONNX.
3. Validate the outputs of the PyTorch and exported models.

In this section, we'll look at how DistilBERT was implemented to show what's involved
with each step.

### Implementing a custom ONNX configuration

Let's start with the ONNX configuration object. We provide three abstract classes that
you should inherit from, depending on the type of model architecture you wish to export:

* Encoder-based models inherit from [`~onnx.config.OnnxConfig`]
* Decoder-based models inherit from [`~onnx.config.OnnxConfigWithPast`]
* Encoder-decoder models inherit from [`~onnx.config.OnnxSeq2SeqConfigWithPast`]

<Tip>

A good way to implement a custom ONNX configuration is to look at the existing
implementation in the `configuration_<model_name>.py` file of a similar architecture.

</Tip>

Since DistilBERT is an encoder-based model, its configuration inherits from
`OnnxConfig`:

```python
>>> from typing import Mapping, OrderedDict
>>> from transformers.onnx import OnnxConfig


>>> class DistilBertOnnxConfig(OnnxConfig):
...     @property
...     def inputs(self) -> Mapping[str, Mapping[int, str]]:
...         return OrderedDict(
...             [
...                 ("input_ids", {0: "batch", 1: "sequence"}),
...                 ("attention_mask", {0: "batch", 1: "sequence"}),
...             ]
...         )
```

Every configuration object must implement the `inputs` property and return a mapping,
where each key corresponds to an expected input, and each value indicates the axis of
that input. For DistilBERT, we can see that two inputs are required: `input_ids` and
`attention_mask`. These inputs share the same shape, `(batch_size, sequence_length)`,
which is why we see the same axes used in the configuration.

<Tip>

Notice that the `inputs` property for `DistilBertOnnxConfig` returns an `OrderedDict`. This
ensures that the inputs are matched with their relative position within the
`PreTrainedModel.forward()` method when tracing the graph. We recommend using an
`OrderedDict` for the `inputs` and `outputs` properties when implementing custom ONNX
configurations.

</Tip>

Once you have implemented an ONNX configuration, you can instantiate it by providing the
base model's configuration as follows:

```python
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config = DistilBertOnnxConfig(config)
```
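
The instance now exposes the `inputs` mapping we defined above, mirroring the `outputs` example shown below:

```python
>>> # Inspect the input axes declared by our custom configuration
>>> print(onnx_config.inputs)
OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}), ('attention_mask', {0: 'batch', 1: 'sequence'})])
```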

The resulting object has several useful properties. For example, you can view the ONNX
operator set that will be used during the export:

```python
>>> print(onnx_config.default_onnx_opset)
11
```

You can also view the outputs associated with the model as follows:

```python
>>> print(onnx_config.outputs)
OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
```

Notice that the outputs property follows the same structure as the inputs; it returns an
`OrderedDict` of named outputs and their shapes. The output structure is linked to the
choice of feature that the configuration is initialized with. By default, the ONNX
configuration is initialized with the `default` feature that corresponds to exporting a
model loaded with the `AutoModel` class. If you want to export a model for another task,
just provide a different feature to the `task` argument when you initialize the ONNX
configuration. For example, if we wished to export DistilBERT with a sequence
classification head, we could use:

```python
>>> from transformers import AutoConfig

>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config_for_seq_clf = DistilBertOnnxConfig(config, task="sequence-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch'})])
```

<Tip>

All of the base properties and methods associated with [`~onnx.config.OnnxConfig`] and
the other configuration classes can be overridden if needed. Check out [`BartOnnxConfig`]
for an advanced example.

</Tip>

### Exporting the model

Once you have implemented the ONNX configuration, the next step is to export the model.
Here we can use the `export()` function provided by the `transformers.onnx` package.
This function expects the ONNX configuration, along with the base model and tokenizer,
and the path to save the exported file:

```python
>>> from pathlib import Path
>>> from transformers.onnx import export
>>> from transformers import AutoTokenizer, AutoModel

>>> onnx_path = Path("model.onnx")
>>> model_ckpt = "distilbert-base-uncased"
>>> base_model = AutoModel.from_pretrained(model_ckpt)
>>> tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

>>> onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
```

The `onnx_inputs` and `onnx_outputs` returned by the `export()` function are lists of
the keys defined in the `inputs` and `outputs` properties of the configuration. Once the
model is exported, you can test that the model is well formed as follows:

```python
>>> import onnx

>>> onnx_model = onnx.load("model.onnx")
>>> onnx.checker.check_model(onnx_model)
```

<Tip>

If your model is larger than 2GB, you will see that many additional files are created
during the export. This is _expected_ because ONNX uses [Protocol
Buffers](https://developers.google.com/protocol-buffers/) to store the model and these
have a size limit of 2GB. See the [ONNX
documentation](https://github.com/onnx/onnx/blob/master/docs/ExternalData.md) for
instructions on how to load models with external data.

</Tip>

### Validating the model outputs

The final step is to validate that the outputs from the base and exported model agree
within some absolute tolerance. Here we can use the `validate_model_outputs()` function
provided by the `transformers.onnx` package as follows:

```python
>>> from transformers.onnx import validate_model_outputs

>>> validate_model_outputs(
...     onnx_config, tokenizer, base_model, onnx_path, onnx_outputs, onnx_config.atol_for_validation
... )
```

This function uses the [`~transformers.onnx.OnnxConfig.generate_dummy_inputs`] method to
generate inputs for the base and exported model, and the absolute tolerance can be
defined in the configuration. We generally find numerical agreement in the 1e-6 to 1e-4
range, although anything smaller than 1e-3 is likely to be OK.
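
If you want to see what is actually fed to both models during validation, you can call this method yourself. A small
sketch using the DistilBERT tokenizer and configuration from above; the printed keys follow the `inputs` mapping we
defined earlier:

```python
>>> from transformers import TensorType

>>> # Dummy tensors used for tracing/validation; the keys mirror `onnx_config.inputs`
>>> dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
>>> print(list(dummy_inputs.keys()))
['input_ids', 'attention_mask']
```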

## Contributing a new configuration to 🤗 Transformers

We are looking to expand the set of ready-made configurations and welcome contributions
from the community! If you would like to contribute your addition to the library, you
will need to:

* Implement the ONNX configuration in the corresponding `configuration_<model_name>.py`
  file
* Include the model architecture and corresponding features in
  [`~onnx.features.FeaturesManager`]
* Add your model architecture to the tests in `test_onnx_v2.py`

Check out how the configuration for [IBERT was
contributed](https://github.com/huggingface/transformers/pull/14868/files) to get an
idea of what's involved.