diff --git a/notebooks/04-onnx-export.ipynb b/notebooks/04-onnx-export.ipynb
index 1bb64ae52b8..4666097c30a 100644
--- a/notebooks/04-onnx-export.ipynb
+++ b/notebooks/04-onnx-export.ipynb
@@ -125,9 +125,39 @@
     "- **Deadcode Elimination**: Remove nodes never accessed in the graph\n",
     "- **Operator Fusing**: Merge multiple instruction into one (Linear -> ReLU can be fused to be LinearReLU)\n",
     "\n",
-    "All of this is done on **onnxruntime** by settings specific `SessionOptions`:"
+    "ONNX Runtime applies most of these optimizations automatically; they can be controlled by setting specific `SessionOptions`.\n",
+    "\n",
+    "Note: Some of the latest optimizations that are not yet integrated into ONNX Runtime are available in the [optimization script](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers), which tunes models for best performance."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "# # An optional step, unless you want a mixed-precision model\n",
+    "# # for performance acceleration on newer GPUs, or you are working with\n",
+    "# # TensorFlow (tf.keras) models or PyTorch models other than BERT\n",
+    "\n",
+    "# !pip install onnxruntime-tools\n",
+    "# from onnxruntime_tools import optimizer\n",
+    "\n",
+    "# # Mixed-precision conversion for the bert-base-cased model converted from PyTorch\n",
+    "# optimized_model = optimizer.optimize_model(\"bert-base-cased.onnx\", model_type='bert', num_heads=12, hidden_size=768)\n",
+    "# optimized_model.convert_model_float32_to_float16()\n",
+    "# optimized_model.save_model_to_file(\"bert-base-cased.onnx\")\n",
+    "\n",
+    "# # Optimizations for the bert-base-cased model converted from TensorFlow (tf.keras)\n",
+    "# optimized_model = optimizer.optimize_model(\"bert-base-cased.onnx\", model_type='bert_keras', num_heads=12, hidden_size=768)\n",
+    "# optimized_model.save_model_to_file(\"bert-base-cased.onnx\")\n"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
   {
    "cell_type": "code",
    "execution_count": 2,
diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py
index fd0787a55f8..7afe974e08c 100644
--- a/src/transformers/convert_graph_to_onnx.py
+++ b/src/transformers/convert_graph_to_onnx.py
@@ -22,6 +22,7 @@ class OnnxConverterArgumentParser(ArgumentParser):
         self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
         self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
         self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
+        self.add_argument("--use-external-format", action="store_true", help="Allow exporting models larger than 2GB")
         self.add_argument("output")


@@ -105,7 +106,7 @@ def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] =
     return pipeline("feature-extraction", model=model, framework=framework)


-def convert_pytorch(nlp: Pipeline, opset: int, output: str):
+def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
     if not is_torch_available():
         raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

@@ -126,7 +127,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str):
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             do_constant_folding=True,
-            use_external_data_format=True,
+            use_external_data_format=use_external_format,
             enable_onnx_checker=True,
             opset_version=opset,
         )
@@ -160,7 +161,14 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
     )


-def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
+def convert(
+    framework: str,
+    model: str,
+    output: str,
+    opset: int,
+    tokenizer: Optional[str] = None,
+    use_external_format: bool = False,
+):
     print("ONNX opset version set to: {}".format(opset))

     # Load the pipeline
@@ -175,7 +183,7 @@ def convert(framework: str, model: str, output: str, opset: int, tokenizer: Opti

     # Export the graph
     if framework == "pt":
-        convert_pytorch(nlp, opset, output)
+        convert_pytorch(nlp, opset, output, use_external_format)
     else:
         convert_tensorflow(nlp, opset, output)

@@ -202,7 +210,7 @@ if __name__ == "__main__":

     try:
         # Convert
-        convert(args.framework, args.model, args.output, args.opset, args.tokenizer)
+        convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)

         # And verify
         if args.check_loading: