Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 23:08:57 +06:00)

Adding optimizations block from ONNXRuntime. (#4431)

* Adding optimizations block from ONNXRuntime.
* Turn off external data format by default for PyTorch export.
* Correct the way use_external_format is passed through the cmdline args.

This commit is contained in:
parent 24538df919
commit ca4a3f4da9
@@ -125,9 +125,39 @@
    "- **Deadcode Elimination**: Remove nodes never accessed in the graph\n",
    "- **Operator Fusing**: Merge multiple instructions into one (Linear -> ReLU can be fused to be LinearReLU)\n",
    "\n",
-   "All of this is done on **onnxruntime** by settings specific `SessionOptions`:"
+   "ONNX Runtime automatically applies most optimizations by setting specific `SessionOptions`.\n",
+   "\n",
+   "Note: Some of the latest optimizations that are not yet integrated into ONNX Runtime are available in the [optimization script](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers) that tunes models for the best performance."
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "outputs": [],
+  "source": [
+   "# # An optional step, unless\n",
+   "# # you want to get a model with mixed precision for perf acceleration on newer GPUs\n",
+   "# # or you are working with Tensorflow (tf.keras) models or PyTorch models other than BERT\n",
+   "\n",
+   "# !pip install onnxruntime-tools\n",
+   "# from onnxruntime_tools import optimizer\n",
+   "\n",
+   "# # Mixed precision conversion for the bert-base-cased model converted from PyTorch\n",
+   "# optimized_model = optimizer.optimize_model(\"bert-base-cased.onnx\", model_type='bert', num_heads=12, hidden_size=768)\n",
+   "# optimized_model.convert_model_float32_to_float16()\n",
+   "# optimized_model.save_model_to_file(\"bert-base-cased.onnx\")\n",
+   "\n",
+   "# # Optimizations for the bert-base-cased model converted from Tensorflow (tf.keras)\n",
+   "# optimized_model = optimizer.optimize_model(\"bert-base-cased.onnx\", model_type='bert_keras', num_heads=12, hidden_size=768)\n",
+   "# optimized_model.save_model_to_file(\"bert-base-cased.onnx\")\n"
+  ],
+  "metadata": {
+   "collapsed": false,
+   "pycharm": {
+    "name": "#%%\n"
+   }
+  }
+ },
  {
   "cell_type": "code",
   "execution_count": 2,
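For readers following along, the `SessionOptions` referenced in the new markdown cell come from the public onnxruntime Python API. Below is a minimal sketch of wiring them up; `bert-base-cased.onnx` is just a placeholder path carried over from the notebook cells, not something this diff produces:

```python
# Minimal sketch: let ONNX Runtime apply its graph optimizations.
# "bert-base-cased.onnx" is a placeholder; substitute your exported model path.
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

sess_options = SessionOptions()

# ORT_ENABLE_ALL covers the passes described above (dead-code elimination,
# operator fusion) plus extended and layout optimizations.
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

# Optionally persist the optimized graph so the work is done only once.
sess_options.optimized_model_filepath = "bert-base-cased.optimized.onnx"

session = InferenceSession("bert-base-cased.onnx", sess_options)
```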
@@ -22,6 +22,7 @@ class OnnxConverterArgumentParser(ArgumentParser):
         self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
         self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
         self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
+        self.add_argument("--use-external-format", action="store_true", help="Allow exporting models larger than 2GB")
         self.add_argument("output")


@@ -105,7 +106,7 @@ def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] =
     return pipeline("feature-extraction", model=model, framework=framework)


-def convert_pytorch(nlp: Pipeline, opset: int, output: str):
+def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
     if not is_torch_available():
         raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

@@ -126,7 +127,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str):
         output_names=output_names,
         dynamic_axes=dynamic_axes,
         do_constant_folding=True,
-        use_external_data_format=True,
+        use_external_data_format=use_external_format,
         enable_onnx_checker=True,
         opset_version=opset,
     )
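`use_external_data_format` matters because ONNX models are serialized as protobuf, which caps a single file at 2GB; with the flag on, large tensors are written to companion files beside the `.onnx` graph. Here is a hedged, self-contained sketch of an export call shaped like the one this hunk patches, assuming a PyTorch release contemporary with this PR (around 1.5, where `torch.onnx.export` still accepted `use_external_data_format`; recent releases dropped the argument):

```python
import torch
import torch.nn as nn

# Toy stand-in model: the real convert_pytorch() exports nlp.model with
# inputs built from the pipeline's tokenizer, which lies outside this diff.
model = nn.Linear(4, 2)
sample_input = torch.randn(1, 4)

torch.onnx.export(
    model,
    (sample_input,),
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    do_constant_folding=True,
    # After this PR the converter passes the CLI flag through here instead
    # of hard-coding True; False keeps everything in one protobuf file.
    use_external_data_format=False,
    opset_version=11,
)
```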
@@ -160,7 +161,14 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
     )


-def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
+def convert(
+    framework: str,
+    model: str,
+    output: str,
+    opset: int,
+    tokenizer: Optional[str] = None,
+    use_external_format: bool = False,
+):
     print("ONNX opset version set to: {}".format(opset))

     # Load the pipeline
@@ -175,7 +183,7 @@ def convert(framework: str, model: str, output: str, opset: int, tokenizer: Opti

     # Export the graph
     if framework == "pt":
-        convert_pytorch(nlp, opset, output)
+        convert_pytorch(nlp, opset, output, use_external_format)
     else:
         convert_tensorflow(nlp, opset, output)

@@ -202,7 +210,7 @@ if __name__ == "__main__":

     try:
         # Convert
-        convert(args.framework, args.model, args.output, args.opset, args.tokenizer)
+        convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)

         # And verify
         if args.check_loading:
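End to end, the flag now travels from `--use-external-format` through `args.use_external_format` into `convert()` and on to `convert_pytorch()`. A hedged sketch of the programmatic call after this change; the import path assumes the converter shipped as `transformers.convert_graph_to_onnx` at the time of this PR:

```python
# Hedged usage sketch against the patched convert() signature shown above.
from transformers.convert_graph_to_onnx import convert

# Default after this PR: use_external_format=False, i.e. a single .onnx
# file, which suits models under protobuf's 2GB ceiling.
convert(framework="pt", model="bert-base-cased", output="onnx/bert-base-cased.onnx", opset=11)

# Models at or above 2GB opt back in, mirroring the new --use-external-format flag:
convert(
    framework="pt",
    model="bert-base-cased",
    output="onnx/bert-base-cased.onnx",
    opset=11,
    tokenizer=None,
    use_external_format=True,
)
```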