Conversion script to export transformers models to ONNX IR. (#4253)

* Added generic ONNX conversion script for PyTorch model.

* WIP initial TF support.

* TensorFlow/Keras ONNX export working.

* Print framework version info

* Add possibility to check the model is correctly loading on ONNX runtime.

* Remove quantization option.

* Specify ONNX opset version when exporting.

* Formatting.

* Remove unused imports.

* Make functions more generally reusable from other part of the code.

* isort happy.

* flake happy

* Export only feature-extraction for now

* Correctly check inputs order / filter before export.

* Removed task variable

* Fix invalid args call in load_graph_from_args.

* Fix invalid args call in convert.

* Fix invalid args call in infer_shapes.

* Raise exception and catch in caller function instead of exit.

* Add 04-onnx-export.ipynb notebook

* More WIP on the notebook

* Remove unused imports

* Simplify & remove unused constants.

* Export with constant_folding in PyTorch

* Let's try to put function args in the right order this time ...

* Disable external_data_format temporary

* ONNX notebook draft ready.

* Updated notebooks charts + wording

* Correct error while exporting last chart in notebook.

* Adressing @LysandreJik comment.

* Set ONNX opset to 11 as default value.

* Set opset param mandatory

* Added ONNX export unittests

* Quality.

* flake8 happy

* Add keras2onnx dependency on extras["tf"]

* Pin keras2onnx on github master to v1.6.5

* Second attempt.

* Third attempt.

* Use the right repo URL this time ...

* Do the same for onnxconverter-common

* Added keras2onnx and onnxconveter-common to 1.7.0 to supports TF2.2

* Correct commit hash.

* Addressing PR review: Optimization are enabled by default.

* Addressing PR review: small changes in the notebook

* setup.py comment about keras2onnx versioning.
This commit is contained in:
Funtowicz Morgan 2020-05-14 20:35:52 +00:00 committed by GitHub
parent 2d05480174
commit db0076a9df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 804 additions and 3 deletions

File diff suppressed because one or more lines are too long

View File

@ -10,9 +10,10 @@ Pull Request and we'll review it so it can be included here.
## Hugging Face's notebooks :hugs:
| Notebook | Description | |
|:----------|:-------------:|------:|
|:----------|:-------------|------:|
| [Getting Started Tokenizers](https://github.com/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) | How to train and use your very own tokenizer |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) |
| [Getting Started Transformers](https://github.com/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) | How to easily start using transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) |
| [How to use Pipelines](https://github.com/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) | Simple and efficient way to use State-of-the-Art models on downstream tasks through transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) |
| [How to train a language model](https://github.com/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)|
| [How to generate text](https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)|
| [How to export model to ONNX](https://github.com/huggingface/blog/blob/master/notebooks/04-onnx-export.ipynb) | Highlight how to export and run inference workloads through ONNX |

View File

@ -67,8 +67,18 @@ extras = {}
extras["mecab"] = ["mecab-python3"]
extras["sklearn"] = ["scikit-learn"]
extras["tf"] = ["tensorflow"]
extras["tf-cpu"] = ["tensorflow-cpu"]
# keras2onnx and onnxconverter-common version is specific through a commit until 1.7.0 lands on pypi
extras["tf"] = [
"tensorflow",
"onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
"keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx"
]
extras["tf-cpu"] = [
"tensorflow-cpu",
"onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
"keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx"
]
extras["torch"] = ["torch"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]

View File

@ -0,0 +1,212 @@
from argparse import ArgumentParser
from itertools import takewhile
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
from transformers import is_tf_available, is_torch_available
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
class OnnxConverterArgumentParser(ArgumentParser):
"""
Wraps all the script arguments supported to export transformers models to ONNX IR
"""
def __init__(self):
super(OnnxConverterArgumentParser, self).__init__("ONNX Converter")
self.add_argument("--model", type=str, required=True, help="Model's id or path (ex: bert-base-cased)")
self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
self.add_argument("output")
def ensure_valid_input(model, tokens, input_names):
"""
Ensure input are presented in the correct order, without any None
Args:
model: The model used to forward the input data
tokens: BatchEncoding holding the input data
input_names: The name of the inputs
Returns: Tuple
"""
model_args_name = model.forward.__code__.co_varnames
model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)
for arg_pos, arg_name in model_args_pos:
model_args[arg_pos] = tokens[arg_name]
model_args = tuple(model_args) # Need to be ordered
return tuple(takewhile(lambda arg: arg is not None, model_args))
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
def build_shape_dict(tensor, is_input: bool, seq_len: int):
if isinstance(tensor, (tuple, list)):
return [build_shape_dict(t, is_input, seq_len) for t in tensor]
else:
# Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
if is_input:
if len(tensor.shape) == 2:
axes[1] = "sequence"
else:
raise ValueError("Unable to infer tensor axes ({})".format(len(tensor.shape)))
else:
seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
axes.update({dim: "sequence" for dim in seq_axes})
return axes
tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
seq_len = tokens.input_ids.shape[-1]
outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
# Generate input names & axes
input_vars = list(tokens.keys())
input_dynamic_axes = {k: build_shape_dict(v, True, seq_len) for k, v in tokens.items()}
# flatten potentially grouped outputs (past for gpt2, attentions)
outputs_flat = []
for output in outputs:
if isinstance(output, (tuple, list)):
outputs_flat.extend(output)
else:
outputs_flat.append(output)
# Generate output names & axes
output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
output_dynamic_axes = {k: build_shape_dict(v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
# Create the aggregated axes representation
dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
return input_vars, output_names, dynamic_axes, tokens
def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
# If no tokenizer provided
if tokenizer is None:
tokenizer = model
print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
# Allocate tokenizer and model
return pipeline("feature-extraction", model=model, framework=framework)
def convert_pytorch(nlp: Pipeline, opset: int, output: str):
if not is_torch_available():
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
import torch
from torch.onnx import export
print("PyTorch: {}".format(torch.__version__))
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
model_args = ensure_valid_input(nlp.model, tokens, input_names)
export(
nlp.model,
model_args,
f=output,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=True,
enable_onnx_checker=True,
opset_version=opset,
)
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
if not is_tf_available():
raise Exception(
"Cannot convert {} because TF is not installed. Please install torch first.".format(args.model)
)
print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
try:
import tensorflow as tf
from keras2onnx import convert_keras, save_model, __version__ as k2ov
print("TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
# Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
# Forward
nlp.model.predict(tokens.data)
onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)
save_model(onnx_model, output)
except ImportError as e:
raise Exception(
"Cannot import {} required to convert TF model to ONNX. Please install {} first.".format(e.name, e.name)
)
def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
print("ONNX opset version set to: {}".format(opset))
# Load the pipeline
nlp = load_graph_from_args(framework, model, tokenizer)
parent = dirname(output)
if not exists(parent):
print("Creating folder {}".format(parent))
makedirs(parent)
elif len(listdir(parent)) > 0:
raise Exception("Folder {} is not empty, aborting conversion".format(parent))
# Export the graph
if framework == "pt":
convert_pytorch(nlp, opset, output)
else:
convert_tensorflow(nlp, opset, output)
def verify(path: str):
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
print("Checking ONNX model loading from: {}".format(path))
try:
onnx_options = SessionOptions()
_ = InferenceSession(path, onnx_options, providers=["CPUExecutionProvider"])
print("Model correctly loaded")
except RuntimeException as re:
print("Error while loading the model: {}".format(re))
if __name__ == "__main__":
parser = OnnxConverterArgumentParser()
args = parser.parse_args()
# Make sure output is absolute path
args.output = abspath(args.output)
try:
# Convert
convert(args.framework, args.model, args.output, args.opset, args.tokenizer)
# And verify
if args.check_loading:
verify(args.output)
except Exception as e:
print("Error while converting the model: {}".format(e))
exit(1)

116
tests/test_onnx.py Normal file
View File

@ -0,0 +1,116 @@
import unittest
from os import sep
from os.path import dirname, exists
from shutil import rmtree
from tests.utils import require_tf, require_torch
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
from transformers.convert_graph_to_onnx import convert, ensure_valid_input, infer_shapes
class FuncContiguousArgs:
def forward(self, input_ids, token_type_ids, attention_mask):
return None
class FuncNonContiguousArgs:
def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
return None
class OnnxExportTestCase(unittest.TestCase):
MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]
@require_tf
def test_export_tensorflow(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "tf", 11)
@require_torch
def test_export_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 11)
def _test_export(self, model, framework, opset):
try:
# Compute path
path = "onnx" + sep + model + ".onnx"
# Remove folder if exists
if exists(dirname(path)):
rmtree(dirname(path))
# Export
convert(framework, model, path, opset)
except Exception as e:
self.fail(e)
@require_torch
def test_infer_dynamic_axis_pytorch(self):
"""
Validate the dynamic axis generated for each parameters are correct
"""
from transformers import BertModel
model = BertModel(BertConfig.from_pretrained("bert-base-cased"))
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
self._test_infer_dynamic_axis(model, tokenizer, "pt")
@require_tf
def test_infer_dynamic_axis_tf(self):
"""
Validate the dynamic axis generated for each parameters are correct
"""
from transformers import TFBertModel
model = TFBertModel(BertConfig.from_pretrained("bert-base-cased"))
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
self._test_infer_dynamic_axis(model, tokenizer, "tf")
def _test_infer_dynamic_axis(self, model, tokenizer, framework):
nlp = FeatureExtractionPipeline(model, tokenizer)
variable_names = ["input_ids", "token_type_ids", "attention_mask", "output_0", "output_1"]
input_vars, output_vars, shapes, tokens = infer_shapes(nlp, framework)
# Assert all variables are present
self.assertEqual(len(shapes), len(variable_names))
self.assertTrue(all([var_name in shapes for var_name in variable_names]))
self.assertSequenceEqual(variable_names[:3], input_vars)
self.assertSequenceEqual(variable_names[3:], output_vars)
# Assert inputs are {0: batch, 1: sequence}
for var_name in ["input_ids", "token_type_ids", "attention_mask"]:
self.assertDictEqual(shapes[var_name], {0: "batch", 1: "sequence"})
# Assert outputs are {0: batch, 1: sequence} and {0: batch}
self.assertDictEqual(shapes["output_0"], {0: "batch", 1: "sequence"})
self.assertDictEqual(shapes["output_1"], {0: "batch"})
def test_ensure_valid_input(self):
"""
Validate parameters are correctly exported
GPT2 has "past" parameter in the middle of input_ids, token_type_ids and attention_mask.
ONNX doesn't support export with a dictionary, only a tuple. Thus we need to ensure we remove
token_type_ids and attention_mask for now to not having a None tensor in the middle
"""
# All generated args are valid
input_names = ["input_ids", "attention_mask", "token_type_ids"]
tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
# Should have exactly the same number of args (all are valid)
self.assertEqual(len(inputs_args), 3)
# Parameter should be reordered according to their respective place in the function:
# (input_ids, token_type_ids, attention_mask)
self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))
# Generated args are interleaved with another args (for instance parameter "past" in GPT2)
inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
# Should have exactly the one arg (all before the one not provided "some_other_args")
self.assertEqual(len(inputs_args), 1)
# Should have only "input_ids"
self.assertEqual(inputs_args[0], tokens["input_ids"])