Fix convert_graph_to_onnx (#5230)

This commit is contained in:
Anthony MOI 2020-06-25 02:17:02 -04:00 committed by GitHub
parent 5543efd5cc
commit 0e1fce3c01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
return input_vars, output_names, dynamic_axes, tokens
def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
# If no tokenizer provided
if tokenizer is None:
tokenizer = model
# Check the wanted framework is available
if framework == "pt" and not is_torch_available():
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
if framework == "tf" and not is_tf_available():
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
# Allocate tokenizer and model
return pipeline(args.pipeline, model=model, tokenizer=tokenizer, framework=framework)
return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)
def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
@@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
if not is_tf_available():
raise Exception(
"Cannot convert {} because TF is not installed. Please install torch first.".format(args.model)
)
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
@@ -187,11 +191,12 @@ def convert(
opset: int,
tokenizer: Optional[str] = None,
use_external_format: bool = False,
pipeline_name: str = "feature-extraction",
):
print("ONNX opset version set to: {}".format(opset))
# Load the pipeline
nlp = load_graph_from_args(framework, model, tokenizer)
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)
parent = dirname(output)
if not exists(parent):
@@ -229,7 +234,15 @@ if __name__ == "__main__":
try:
# Convert
convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)
convert(
args.framework,
args.model,
args.output,
args.opset,
args.tokenizer,
args.use_external_format,
args.pipeline,
)
# And verify
if args.check_loading: