Mirror of https://github.com/huggingface/transformers.git
Improve ONNX logging (#4999)
* Improve ONNX export logging to give more information about the generated graph.

* Correctly handle input and output in the logging.
parent 9931f817b7
commit 9ad36ad57f
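Taken together, the new print statements make a conversion run report the framework version, the dynamic axes discovered for each named tensor, and the final input ordering. As a rough illustration (tensor names, version numbers, and axes below are invented for the example, not taken from an actual run), the console output now looks roughly like:

Using framework PyTorch: 1.5.0
Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch', 1: 'sequence'}
Ensuring inputs are in correct order
head_mask is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask']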
@@ -36,24 +36,26 @@ def ensure_valid_input(model, tokens, input_names):
     Returns: Tuple
 
     """
-    model_args_name = model.forward.__code__.co_varnames
+    print("Ensuring inputs are in correct order")
 
-    ordered_input_names = []
-    model_args = []
+    model_args_name = model.forward.__code__.co_varnames
+    model_args, ordered_input_names = [], []
     for arg_name in model_args_name[1:]:  # start at index 1 to skip "self" argument
         if arg_name in input_names:
             ordered_input_names.append(arg_name)
             model_args.append(tokens[arg_name])
         else:
+            print("{} is not present in the generated input list.".format(arg_name))
             break
 
+    print("Generated inputs order: {}".format(ordered_input_names))
     return ordered_input_names, tuple(model_args)
 
 
 def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
-    def build_shape_dict(tensor, is_input: bool, seq_len: int):
+    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
         if isinstance(tensor, (tuple, list)):
-            return [build_shape_dict(t, is_input, seq_len) for t in tensor]
+            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
 
         else:
             # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
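For readers less familiar with the introspection trick used above: the ordering relies only on the declaration order of the forward() arguments. A minimal, self-contained sketch of the same pattern follows; DummyModel and the token values are placeholders invented for this example, not part of transformers.

from typing import Dict, List, Tuple


class DummyModel:
    """Stand-in for a PyTorch module; only the forward() signature matters here."""

    def forward(self, input_ids, attention_mask, head_mask=None):
        return input_ids


def order_inputs(model, tokens: Dict[str, list], input_names: List[str]) -> Tuple[List[str], tuple]:
    # Argument names of forward() in declaration order; index 0 is "self".
    model_args_name = model.forward.__code__.co_varnames
    model_args, ordered_input_names = [], []
    for arg_name in model_args_name[1:]:  # start at index 1 to skip "self"
        if arg_name in input_names:
            ordered_input_names.append(arg_name)
            model_args.append(tokens[arg_name])
        else:
            print("{} is not present in the generated input list.".format(arg_name))
            break
    print("Generated inputs order: {}".format(ordered_input_names))
    return ordered_input_names, tuple(model_args)


tokens = {"input_ids": [[101, 2023, 102]], "attention_mask": [[1, 1, 1]]}
names, args = order_inputs(DummyModel(), tokens, list(tokens.keys()))
# names == ['input_ids', 'attention_mask']; head_mask triggers the message and the loop stops.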
@@ -67,6 +69,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
             seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
             axes.update({dim: "sequence" for dim in seq_axes})
 
+        print("Found {} {} with shape: {}".format("input" if is_input else "output", name, axes))
         return axes
 
     tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
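The helper that gained the name parameter tags the singleton axis as the batch dimension and every axis matching the tokenized sequence length as the sequence dimension, then logs what it found. A rough standalone approximation of that logic (using a NumPy array as the tensor and an invented name; not the exact transformers implementation) could be:

import numpy as np


def build_shape_dict(name, tensor, is_input, seq_len):
    # Grouped outputs (tuples/lists) are handled recursively under the same logical name.
    if isinstance(tensor, (tuple, list)):
        return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]

    # Assume the axis holding a single element is the batch axis.
    axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
    # Every axis whose size equals the sequence length is treated as dynamic over "sequence".
    axes.update({dim: "sequence" for dim, size in enumerate(tensor.shape) if size == seq_len})

    print("Found {} {} with shape: {}".format("input" if is_input else "output", name, axes))
    return axes


# Fake "input_ids": batch of 1, sequence length 6.
build_shape_dict("input_ids", np.zeros((1, 6), dtype=np.int64), True, 6)
# Prints: Found input input_ids with shape: {0: 'batch', 1: 'sequence'}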
@@ -78,7 +81,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
 
     # Generate input names & axes
     input_vars = list(tokens.keys())
-    input_dynamic_axes = {k: build_shape_dict(v, True, seq_len) for k, v in tokens.items()}
+    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
 
     # flatten potentially grouped outputs (past for gpt2, attentions)
     outputs_flat = []
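With the dictionary key now passed through as the tensor name, the comprehension above yields one axes mapping per tokenizer field and logs each one as it goes. Reusing the build_shape_dict sketch above with placeholder arrays (toy shapes, not a real encode_plus call):

seq_len = 6
tokens = {
    "input_ids": np.zeros((1, seq_len), dtype=np.int64),
    "attention_mask": np.ones((1, seq_len), dtype=np.int64),
}

# Same shape as the real comprehension: the dict key doubles as the logged tensor name.
input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
# -> {'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}}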
@@ -90,7 +93,7 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
 
     # Generate output names & axes
     output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
-    output_dynamic_axes = {k: build_shape_dict(v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
+    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
 
     # Create the aggregated axes representation
     dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
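The output side mirrors the input side, and the two per-tensor mappings are then merged into the single dynamic_axes dictionary handed to the exporters. The merge is an ordinary dict union; with illustrative values:

input_dynamic_axes = {"input_ids": {0: "batch", 1: "sequence"}}
output_dynamic_axes = {"output_0": {0: "batch", 1: "sequence"}}

# dict(a, **b) copies a and overlays b, producing one mapping keyed by tensor name.
dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
# -> {'input_ids': {0: 'batch', 1: 'sequence'}, 'output_0': {0: 'batch', 1: 'sequence'}}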
@@ -115,7 +118,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
     import torch
     from torch.onnx import export
 
-    print("PyTorch: {}".format(torch.__version__))
+    print("Using framework PyTorch: {}".format(torch.__version__))
 
     with torch.no_grad():
         input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
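For context, the names and dynamic axes gathered by infer_shapes are the kind of values a torch.onnx.export call consumes a few lines further down in this function. A minimal sketch of such an export, with a placeholder linear model and file path rather than the pipeline's actual model and arguments:

import torch

print("Using framework PyTorch: {}".format(torch.__version__))

# Placeholder module and input; the real code exports nlp.model with the inferred values.
model = torch.nn.Linear(4, 2)
dummy_input = torch.zeros(1, 4)

torch.onnx.export(
    model,
    (dummy_input,),
    "model.onnx",  # placeholder output path
    input_names=["input"],  # would be the ordered input names from ensure_valid_input
    output_names=["output_0"],
    dynamic_axes={"input": {0: "batch"}, "output_0": {0: "batch"}},
    opset_version=11,
)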
@@ -147,7 +150,7 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
     import tensorflow as tf
     from keras2onnx import convert_keras, save_model, __version__ as k2ov
 
-    print("TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
+    print("Using framework TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))
 
     # Build
     input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
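On the TensorFlow path the renamed log line precedes the keras2onnx conversion performed later in this function. A minimal sketch of that conversion, with a toy Keras model standing in for nlp.model and a placeholder output path:

import tensorflow as tf
from keras2onnx import convert_keras, save_model, __version__ as k2ov

print("Using framework TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))

# Toy Keras model standing in for the pipeline's TF model.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
onnx_model = convert_keras(model, model.name, target_opset=11)
save_model(onnx_model, "model.onnx")  # placeholder output path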