mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix convert_graph_to_onnx (#5230)
This commit is contained in:
parent
5543efd5cc
commit
0e1fce3c01
@ -114,15 +114,21 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
|
||||
return input_vars, output_names, dynamic_axes, tokens
|
||||
|
||||
|
||||
def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
|
||||
def load_graph_from_args(pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
|
||||
# If no tokenizer provided
|
||||
if tokenizer is None:
|
||||
tokenizer = model
|
||||
|
||||
# Check the wanted framework is available
|
||||
if framework == "pt" and not is_torch_available():
|
||||
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
if framework == "tf" and not is_tf_available():
|
||||
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
|
||||
|
||||
print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))
|
||||
|
||||
# Allocate tokenizer and model
|
||||
return pipeline(args.pipeline, model=model, tokenizer=tokenizer, framework=framework)
|
||||
return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework)
|
||||
|
||||
|
||||
def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
|
||||
@ -154,9 +160,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
|
||||
|
||||
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
|
||||
if not is_tf_available():
|
||||
raise Exception(
|
||||
"Cannot convert {} because TF is not installed. Please install torch first.".format(args.model)
|
||||
)
|
||||
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
|
||||
|
||||
print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
|
||||
|
||||
@ -187,11 +191,12 @@ def convert(
|
||||
opset: int,
|
||||
tokenizer: Optional[str] = None,
|
||||
use_external_format: bool = False,
|
||||
pipeline_name: str = "feature-extraction",
|
||||
):
|
||||
print("ONNX opset version set to: {}".format(opset))
|
||||
|
||||
# Load the pipeline
|
||||
nlp = load_graph_from_args(framework, model, tokenizer)
|
||||
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer)
|
||||
|
||||
parent = dirname(output)
|
||||
if not exists(parent):
|
||||
@ -229,7 +234,15 @@ if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
# Convert
|
||||
convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)
|
||||
convert(
|
||||
args.framework,
|
||||
args.model,
|
||||
args.output,
|
||||
args.opset,
|
||||
args.tokenizer,
|
||||
args.use_external_format,
|
||||
args.pipeline,
|
||||
)
|
||||
|
||||
# And verify
|
||||
if args.check_loading:
|
||||
|
Loading…
Reference in New Issue
Block a user