Improve ONNX logging (#4999)

* Improve ONNX export logging to give more information about the generated graph.

* Correctly handle input and output in the logging.
This commit is contained in:
Funtowicz Morgan 2020-06-15 11:04:51 +02:00 committed by GitHub
parent 9931f817b7
commit 9ad36ad57f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,24 +36,26 @@ def ensure_valid_input(model, tokens, input_names):
Returns: Tuple
"""
model_args_name = model.forward.__code__.co_varnames
print("Ensuring inputs are in correct order")
ordered_input_names = []
model_args = []
model_args_name = model.forward.__code__.co_varnames
model_args, ordered_input_names = [], []
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
if arg_name in input_names:
ordered_input_names.append(arg_name)
model_args.append(tokens[arg_name])
else:
print("{} is not present in the generated input list.".format(arg_name))
break
print("Generated inputs order: {}".format(ordered_input_names))
return ordered_input_names, tuple(model_args)
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
def build_shape_dict(tensor, is_input: bool, seq_len: int):
def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
if isinstance(tensor, (tuple, list)):
return [build_shape_dict(t, is_input, seq_len) for t in tensor]
return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
else:
# Let's assume batch is the first axis with only 1 element (this might not always be true ...)
@ -67,6 +69,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
axes.update({dim: "sequence" for dim in seq_axes})
print("Found {} {} with shape: {}".format("input" if is_input else "output", name, axes))
return axes
tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
@ -78,7 +81,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
# Generate input names & axes
input_vars = list(tokens.keys())
input_dynamic_axes = {k: build_shape_dict(v, True, seq_len) for k, v in tokens.items()}
input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
# flatten potentially grouped outputs (past for gpt2, attentions)
outputs_flat = []
@ -90,7 +93,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
# Generate output names & axes
output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
output_dynamic_axes = {k: build_shape_dict(v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
# Create the aggregated axes representation
dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
@ -115,7 +118,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
import torch
from torch.onnx import export
print("PyTorch: {}".format(torch.__version__))
print("Using framework PyTorch: {}".format(torch.__version__))
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
@ -147,7 +150,7 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
import tensorflow as tf
from keras2onnx import convert_keras, save_model, __version__ as k2ov
print("TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
print("Using framework TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
# Build
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")