Fix onnx export input names order (#4641)

* pass on tokenizer to pipeline

* order input names when convert to onnx

* update style

* remove unused imports

* make ordered inputs list needs to be mutable

* add test custom bert model

* remove unused imports
This commit is contained in:
Rens 2020-06-01 16:12:48 +02:00 committed by GitHub
parent bf760c80b5
commit ec62b7d953
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 16 deletions

View File

@ -1,5 +1,4 @@
from argparse import ArgumentParser
from itertools import takewhile
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
@ -38,14 +37,17 @@ def ensure_valid_input(model, tokens, input_names):
"""
model_args_name = model.forward.__code__.co_varnames
model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)
for arg_pos, arg_name in model_args_pos:
model_args[arg_pos] = tokens[arg_name]
ordered_input_names = []
model_args = []
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
if arg_name in input_names:
ordered_input_names.append(arg_name)
model_args.append(tokens[arg_name])
else:
break
model_args = tuple(model_args) # Need to be ordered
return tuple(takewhile(lambda arg: arg is not None, model_args))
return ordered_input_names, tuple(model_args)
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
@ -117,13 +119,13 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
with torch.no_grad():
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
model_args = ensure_valid_input(nlp.model, tokens, input_names)
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
export(
nlp.model,
model_args,
f=output,
input_names=input_names,
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,

View File

@ -1,7 +1,7 @@
import unittest
from os import sep
from os.path import dirname, exists
from shutil import rmtree
from tempfile import NamedTemporaryFile, TemporaryDirectory
from tests.utils import require_tf, require_torch, slow
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
@ -33,17 +33,34 @@ class OnnxExportTestCase(unittest.TestCase):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 11)
def _test_export(self, model, framework, opset):
@require_torch
@slow
def test_export_custom_bert_model(self):
from transformers import BertModel
vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
with NamedTemporaryFile(mode="w+t") as vocab_file:
vocab_file.write("\n".join(vocab))
vocab_file.flush()
tokenizer = BertTokenizerFast(vocab_file.name)
with TemporaryDirectory() as bert_save_dir:
model = BertModel(BertConfig(vocab_size=len(vocab)))
model.save_pretrained(bert_save_dir)
self._test_export(bert_save_dir, "pt", 11, tokenizer)
def _test_export(self, model, framework, opset, tokenizer=None):
try:
# Compute path
path = "onnx" + sep + model + ".onnx"
with TemporaryDirectory() as tempdir:
path = tempdir + "/model.onnx"
# Remove folder if exists
if exists(dirname(path)):
rmtree(dirname(path))
# Export
convert(framework, model, path, opset)
# Export
convert(framework, model, path, opset, tokenizer)
except Exception as e:
self.fail(e)
@ -99,20 +116,25 @@ class OnnxExportTestCase(unittest.TestCase):
# All generated args are valid
input_names = ["input_ids", "attention_mask", "token_type_ids"]
tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
# Should have exactly the same number of args (all are valid)
self.assertEqual(len(inputs_args), 3)
# Should have exactly the same input names
self.assertEqual(set(ordered_input_names), set(input_names))
# Parameter should be reordered according to their respective place in the function:
# (input_ids, token_type_ids, attention_mask)
self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))
# Generated args are interleaved with another args (for instance parameter "past" in GPT2)
inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
# Should have exactly the one arg (all before the one not provided "some_other_args")
self.assertEqual(len(inputs_args), 1)
self.assertEqual(len(ordered_input_names), 1)
# Should have only "input_ids"
self.assertEqual(inputs_args[0], tokens["input_ids"])
self.assertEqual(ordered_input_names[0], "input_ids")