mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Conversion script to export transformers models to ONNX IR. (#4253)
* Added generic ONNX conversion script for PyTorch models.
* WIP initial TF support.
* TensorFlow/Keras ONNX export working.
* Print framework version info.
* Add possibility to check the model loads correctly on ONNX Runtime.
* Remove quantization option.
* Specify ONNX opset version when exporting.
* Formatting.
* Remove unused imports.
* Make functions more generally reusable from other parts of the code.
* isort happy.
* flake happy.
* Export only feature-extraction for now.
* Correctly check inputs order / filter before export.
* Removed task variable.
* Fix invalid args call in load_graph_from_args.
* Fix invalid args call in convert.
* Fix invalid args call in infer_shapes.
* Raise exception and catch in caller function instead of exit.
* Add 04-onnx-export.ipynb notebook.
* More WIP on the notebook.
* Remove unused imports.
* Simplify & remove unused constants.
* Export with constant folding in PyTorch.
* Let's try to put function args in the right order this time...
* Disable external_data_format temporarily.
* ONNX notebook draft ready.
* Updated notebook charts + wording.
* Correct error while exporting last chart in notebook.
* Addressing @LysandreJik's comment.
* Set ONNX opset to 11 as default value.
* Make opset param mandatory.
* Added ONNX export unit tests.
* Quality.
* flake8 happy.
* Add keras2onnx dependency on extras["tf"].
* Pin keras2onnx on github master to v1.6.5.
* Second attempt.
* Third attempt.
* Use the right repo URL this time...
* Do the same for onnxconverter-common.
* Bumped keras2onnx and onnxconverter-common pins towards 1.7.0 to support TF 2.2.
* Correct commit hash.
* Addressing PR review: optimizations are enabled by default.
* Addressing PR review: small changes in the notebook.
* setup.py comment about keras2onnx versioning.
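The converter can also be driven programmatically; a minimal sketch based on the `convert` entry point added in this commit (the model id and output path are illustrative):

# Sketch: export a model through the new converter; ids/paths are illustrative.
# The CLI equivalent added by this commit:
#   python src/transformers/convert_graph_to_onnx.py --framework pt --model bert-base-cased \
#       --opset 11 --check-loading onnx/bert-base-cased.onnx
from transformers.convert_graph_to_onnx import convert

# Loads a feature-extraction pipeline, infers input/output names and dynamic
# axes, then exports the graph at ONNX opset 11 (the default)
convert(framework="pt", model="bert-base-cased", output="onnx/bert-base-cased.onnx", opset=11)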
This commit is contained in:
parent 2d05480174
commit db0076a9df

462 notebooks/04-onnx-export.ipynb (new file)
File diff suppressed because one or more lines are too long
@@ -10,9 +10,10 @@ Pull Request and we'll review it so it can be included here.
 ## Hugging Face's notebooks :hugs:
 
 | Notebook | Description | |
-|:----------|:-------------:|------:|
+|:----------|:-------------|------:|
 | [Getting Started Tokenizers](https://github.com/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) | How to train and use your very own tokenizer | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) |
 | [Getting Started Transformers](https://github.com/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) | How to easily start using transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) |
 | [How to use Pipelines](https://github.com/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) | Simple and efficient way to use State-of-the-Art models on downstream tasks through transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) |
 | [How to train a language model](https://github.com/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)|
 | [How to generate text](https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)|
+| [How to export model to ONNX](https://github.com/huggingface/blog/blob/master/notebooks/04-onnx-export.ipynb) | Highlight how to export and run inference workloads through ONNX |
14 setup.py
@@ -67,8 +67,18 @@ extras = {}
 
 extras["mecab"] = ["mecab-python3"]
 extras["sklearn"] = ["scikit-learn"]
-extras["tf"] = ["tensorflow"]
-extras["tf-cpu"] = ["tensorflow-cpu"]
+
+# keras2onnx and onnxconverter-common are pinned to a specific commit until 1.7.0 lands on PyPI
+extras["tf"] = [
+    "tensorflow",
+    "onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
+    "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx"
+]
+extras["tf-cpu"] = [
+    "tensorflow-cpu",
+    "onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
+    "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx"
+]
 extras["torch"] = ["torch"]
 
 extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
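With these pins in place, installing the extra from a source checkout, e.g. `pip install -e ".[tf]"`, should pull keras2onnx and onnxconverter-common directly from the referenced commits — the stopgap the comment calls out until 1.7.0 is published on PyPI.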
212 src/transformers/convert_graph_to_onnx.py (new file)
@@ -0,0 +1,212 @@
from argparse import ArgumentParser
from itertools import takewhile
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple

from transformers import is_tf_available, is_torch_available
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding


class OnnxConverterArgumentParser(ArgumentParser):
    """
    Wraps all the script arguments supported to export transformers models to ONNX IR
    """

    def __init__(self):
        super(OnnxConverterArgumentParser, self).__init__("ONNX Converter")

        self.add_argument("--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)")
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
        self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
        self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
        self.add_argument("output")


def ensure_valid_input(model, tokens, input_names):
    """
    Ensure inputs are presented in the correct order, without any None

    Args:
        model: The model used to forward the input data
        tokens: BatchEncoding holding the input data
        input_names: The name of the inputs

    Returns: Tuple
    """
    # Map each required input name to its position in the model's forward signature
    # (index - 1 skips the implicit "self" argument)
    model_args_name = model.forward.__code__.co_varnames
    model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
    model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)

    for arg_pos, arg_name in model_args_pos:
        model_args[arg_pos] = tokens[arg_name]

    # ONNX export takes a tuple of positional args: keep only the leading,
    # contiguous block of non-None arguments
    model_args = tuple(model_args)  # Need to be ordered
    return tuple(takewhile(lambda arg: arg is not None, model_args))


def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    def build_shape_dict(tensor, is_input: bool, seq_len: int):
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(t, is_input, seq_len) for t in tensor]

        else:
            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError("Unable to infer tensor axes ({})".format(len(tensor.shape)))
            else:
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update({dim: "sequence" for dim in seq_axes})

        return axes

    tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)

    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(v, True, seq_len) for k, v in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # Create the aggregated axes representation
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens


def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
    # If no tokenizer provided, fall back to the model id
    if tokenizer is None:
        tokenizer = model

    print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))

    # Allocate tokenizer and model
    return pipeline("feature-extraction", model=model, tokenizer=tokenizer, framework=framework)


def convert_pytorch(nlp: Pipeline, opset: int, output: str):
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print("PyTorch: {}".format(torch.__version__))

    with torch.no_grad():
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        model_args = ensure_valid_input(nlp.model, tokens, input_names)

        export(
            nlp.model,
            model_args,
            f=output,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=True,
            enable_onnx_checker=True,
            opset_version=opset,
        )


def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")

    try:
        import tensorflow as tf
        from keras2onnx import convert_keras, save_model, __version__ as k2ov

        print("TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))

        # Build
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")

        # Forward once so the Keras graph is built before conversion
        nlp.model.predict(tokens.data)
        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)
        save_model(onnx_model, output)

    except ImportError as e:
        raise Exception(
            "Cannot import {} required to convert TF model to ONNX. Please install {} first.".format(e.name, e.name)
        )


def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
    print("ONNX opset version set to: {}".format(opset))

    # Load the pipeline
    nlp = load_graph_from_args(framework, model, tokenizer)

    parent = dirname(output)
    if not exists(parent):
        print("Creating folder {}".format(parent))
        makedirs(parent)
    elif len(listdir(parent)) > 0:
        raise Exception("Folder {} is not empty, aborting conversion".format(parent))

    # Export the graph
    if framework == "pt":
        convert_pytorch(nlp, opset, output)
    else:
        convert_tensorflow(nlp, opset, output)


def verify(path: str):
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print("Checking ONNX model loading from: {}".format(path))
    try:
        onnx_options = SessionOptions()
        _ = InferenceSession(path, onnx_options, providers=["CPUExecutionProvider"])
        print("Model correctly loaded")
    except RuntimeException as re:
        print("Error while loading the model: {}".format(re))


if __name__ == "__main__":
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # Make sure output is absolute path
    args.output = abspath(args.output)

    try:
        # Convert
        convert(args.framework, args.model, args.output, args.opset, args.tokenizer)

        # And verify
        if args.check_loading:
            verify(args.output)
    except Exception as e:
        print("Error while converting the model: {}".format(e))
        exit(1)
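Once exported, the model can be loaded back the same way `verify` does and run directly with ONNX Runtime. A minimal sketch, assuming a BERT-style export where the inputs are `input_ids`, `token_type_ids` and `attention_mask` and the outputs are named `output_0`, `output_1`, ... as generated by `infer_shapes` (the file path is illustrative):

# Sketch: run the exported graph with ONNX Runtime (assumes a BERT-style export)
import numpy as np
from onnxruntime import InferenceSession, SessionOptions
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
session = InferenceSession("onnx/bert-base-cased.onnx", SessionOptions(), providers=["CPUExecutionProvider"])

# Feed int64 numpy arrays keyed by the input names used at export time,
# adding a leading batch dimension
tokens = tokenizer.encode_plus("This is a sample output")
inputs = {name: np.array([value], dtype=np.int64) for name, value in tokens.items()}
outputs = session.run(None, inputs)  # None -> return every output (output_0, output_1, ...)
print(outputs[0].shape)  # (batch, sequence, hidden_size)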
116 tests/test_onnx.py (new file)
@@ -0,0 +1,116 @@
import unittest
from os import sep
from os.path import dirname, exists
from shutil import rmtree

from tests.utils import require_tf, require_torch
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
from transformers.convert_graph_to_onnx import convert, ensure_valid_input, infer_shapes


class FuncContiguousArgs:
    def forward(self, input_ids, token_type_ids, attention_mask):
        return None


class FuncNonContiguousArgs:
    def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
        return None


class OnnxExportTestCase(unittest.TestCase):
    MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]

    @require_tf
    def test_export_tensorflow(self):
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            self._test_export(model, "tf", 11)

    @require_torch
    def test_export_pytorch(self):
        for model in OnnxExportTestCase.MODEL_TO_TEST:
            self._test_export(model, "pt", 11)

    def _test_export(self, model, framework, opset):
        try:
            # Compute path
            path = "onnx" + sep + model + ".onnx"

            # Remove folder if exists
            if exists(dirname(path)):
                rmtree(dirname(path))

            # Export
            convert(framework, model, path, opset)
        except Exception as e:
            self.fail(e)

    @require_torch
    def test_infer_dynamic_axis_pytorch(self):
        """
        Validate the dynamic axes generated for each parameter are correct
        """
        from transformers import BertModel

        model = BertModel(BertConfig.from_pretrained("bert-base-cased"))
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        self._test_infer_dynamic_axis(model, tokenizer, "pt")

    @require_tf
    def test_infer_dynamic_axis_tf(self):
        """
        Validate the dynamic axes generated for each parameter are correct
        """
        from transformers import TFBertModel

        model = TFBertModel(BertConfig.from_pretrained("bert-base-cased"))
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        self._test_infer_dynamic_axis(model, tokenizer, "tf")

    def _test_infer_dynamic_axis(self, model, tokenizer, framework):
        nlp = FeatureExtractionPipeline(model, tokenizer)

        variable_names = ["input_ids", "token_type_ids", "attention_mask", "output_0", "output_1"]
        input_vars, output_vars, shapes, tokens = infer_shapes(nlp, framework)

        # Assert all variables are present
        self.assertEqual(len(shapes), len(variable_names))
        self.assertTrue(all([var_name in shapes for var_name in variable_names]))
        self.assertSequenceEqual(variable_names[:3], input_vars)
        self.assertSequenceEqual(variable_names[3:], output_vars)

        # Assert inputs are {0: batch, 1: sequence}
        for var_name in ["input_ids", "token_type_ids", "attention_mask"]:
            self.assertDictEqual(shapes[var_name], {0: "batch", 1: "sequence"})

        # Assert outputs are {0: batch, 1: sequence} and {0: batch}
        self.assertDictEqual(shapes["output_0"], {0: "batch", 1: "sequence"})
        self.assertDictEqual(shapes["output_1"], {0: "batch"})

    def test_ensure_valid_input(self):
        """
        Validate parameters are correctly exported.
        GPT2 has a "past" parameter in the middle of input_ids, token_type_ids and attention_mask.
        ONNX doesn't support export with a dictionary, only a tuple. Thus we need to drop
        token_type_ids and attention_mask for now, to avoid having a None tensor in the middle.
        """
        # All generated args are valid
        input_names = ["input_ids", "attention_mask", "token_type_ids"]
        tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
        inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)

        # Should have exactly the same number of args (all are valid)
        self.assertEqual(len(inputs_args), 3)

        # Parameters should be reordered according to their respective place in the function:
        # (input_ids, token_type_ids, attention_mask)
        self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))

        # Generated args are interleaved with other args (for instance the "past" parameter in GPT2)
        inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)

        # Should have exactly one arg (everything before the missing "some_other_args")
        self.assertEqual(len(inputs_args), 1)

        # Should have only "input_ids"
        self.assertEqual(inputs_args[0], tokens["input_ids"])
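Assuming the matching framework extra is installed, the new suite should run under pytest as usual, e.g. `python -m pytest tests/test_onnx.py -v`; the `@require_tf` / `@require_torch` decorators skip the cases whose framework is unavailable.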