From ec62b7d9536eed3f4448d9aa3779e45302dccbd0 Mon Sep 17 00:00:00 2001
From: Rens
Date: Mon, 1 Jun 2020 16:12:48 +0200
Subject: [PATCH] Fix onnx export input names order (#4641)

* pass on tokenizer to pipeline
* order input names when converting to onnx
* update style
* remove unused imports
* make ordered inputs list mutable
* add test for custom bert model
* remove unused imports
---
 src/transformers/convert_graph_to_onnx.py | 20 +++++++------
 tests/test_onnx.py                        | 36 ++++++++++++++++++++-----
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py
index 567e03d7211..9389f4777c6 100644
--- a/src/transformers/convert_graph_to_onnx.py
+++ b/src/transformers/convert_graph_to_onnx.py
@@ -1,5 +1,4 @@
 from argparse import ArgumentParser
-from itertools import takewhile
 from os import listdir, makedirs
 from os.path import abspath, dirname, exists
 from typing import Dict, List, Optional, Tuple
@@ -38,14 +37,17 @@ def ensure_valid_input(model, tokens, input_names):
     """
     model_args_name = model.forward.__code__.co_varnames
-    model_args_pos = [(model_args_name.index(name) - 1, name) for name in input_names]
-    model_args = [None] * (max(map(lambda x: x[0], model_args_pos)) + 1)
-    for arg_pos, arg_name in model_args_pos:
-        model_args[arg_pos] = tokens[arg_name]
+    ordered_input_names = []
+    model_args = []
+    for arg_name in model_args_name[1:]:  # start at index 1 to skip "self" argument
+        if arg_name in input_names:
+            ordered_input_names.append(arg_name)
+            model_args.append(tokens[arg_name])
+        else:
+            break
 
-    model_args = tuple(model_args)  # Need to be ordered
-    return tuple(takewhile(lambda arg: arg is not None, model_args))
+    return ordered_input_names, tuple(model_args)
 
 
 def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
@@ -117,13 +119,13 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format:
 
     with torch.no_grad():
         input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
-        model_args = ensure_valid_input(nlp.model, tokens, input_names)
+        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
 
         export(
             nlp.model,
             model_args,
             f=output,
-            input_names=input_names,
+            input_names=ordered_input_names,
             output_names=output_names,
             dynamic_axes=dynamic_axes,
             do_constant_folding=True,
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
index e0565bf7186..98431fd491c 100644
--- a/tests/test_onnx.py
+++ b/tests/test_onnx.py
@@ -1,7 +1,7 @@
 import unittest
-from os import sep
 from os.path import dirname, exists
 from shutil import rmtree
+from tempfile import NamedTemporaryFile, TemporaryDirectory
 
 from tests.utils import require_tf, require_torch, slow
 from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
@@ -33,17 +33,34 @@ class OnnxExportTestCase(unittest.TestCase):
         for model in OnnxExportTestCase.MODEL_TO_TEST:
             self._test_export(model, "pt", 11)
 
-    def _test_export(self, model, framework, opset):
+    @require_torch
+    @slow
+    def test_export_custom_bert_model(self):
+        from transformers import BertModel
+
+        vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
+        with NamedTemporaryFile(mode="w+t") as vocab_file:
+            vocab_file.write("\n".join(vocab))
+            vocab_file.flush()
+            tokenizer = BertTokenizerFast(vocab_file.name)
+
+        with TemporaryDirectory() as bert_save_dir:
+            model = BertModel(BertConfig(vocab_size=len(vocab)))
+            model.save_pretrained(bert_save_dir)
+            self._test_export(bert_save_dir, "pt", 11, tokenizer)
+
+    def _test_export(self, model, framework, opset, tokenizer=None):
         try:
             # Compute path
-            path = "onnx" + sep + model + ".onnx"
+            with TemporaryDirectory() as tempdir:
+                path = tempdir + "/model.onnx"
 
             # Remove folder if exists
             if exists(dirname(path)):
                 rmtree(dirname(path))
 
-            # Export
-            convert(framework, model, path, opset)
+            # Export
+            convert(framework, model, path, opset, tokenizer)
 
         except Exception as e:
             self.fail(e)
 
@@ -99,20 +116,25 @@ class OnnxExportTestCase(unittest.TestCase):
         # All generated args are valid
         input_names = ["input_ids", "attention_mask", "token_type_ids"]
         tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
-        inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
+        ordered_input_names, inputs_args = ensure_valid_input(FuncContiguousArgs(), tokens, input_names)
 
         # Should have exactly the same number of args (all are valid)
         self.assertEqual(len(inputs_args), 3)
 
+        # Should have exactly the same input names
+        self.assertEqual(set(ordered_input_names), set(input_names))
+
         # Parameter should be reordered according to their respective place in the function:
         # (input_ids, token_type_ids, attention_mask)
         self.assertEqual(inputs_args, (tokens["input_ids"], tokens["token_type_ids"], tokens["attention_mask"]))
 
         # Generated args are interleaved with another args (for instance parameter "past" in GPT2)
-        inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
+        ordered_input_names, inputs_args = ensure_valid_input(FuncNonContiguousArgs(), tokens, input_names)
 
         # Should have exactly the one arg (all before the one not provided "some_other_args")
         self.assertEqual(len(inputs_args), 1)
+        self.assertEqual(len(ordered_input_names), 1)
 
         # Should have only "input_ids"
         self.assertEqual(inputs_args[0], tokens["input_ids"])
+        self.assertEqual(ordered_input_names[0], "input_ids")
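
Why the ordering matters: torch.onnx.export binds each entry of input_names
positionally to the argument tuple it traces, so both must follow the
declaration order of model.forward, not the order of the tokenizer's output
dict. Below is a minimal, self-contained sketch of the new ensure_valid_input
behaviour; the two stub classes mirror the FuncContiguousArgs and
FuncNonContiguousArgs fixtures from tests/test_onnx.py and stand in for real
models.

    def ensure_valid_input(model, tokens, input_names):
        # Positional parameters come first in co_varnames; index 0 is "self".
        model_args_name = model.forward.__code__.co_varnames

        ordered_input_names = []
        model_args = []
        for arg_name in model_args_name[1:]:
            if arg_name in input_names:
                ordered_input_names.append(arg_name)
                model_args.append(tokens[arg_name])
            else:
                # Stop at the first parameter the tokenizer did not produce
                # (e.g. "past" for GPT-2): anything after it could not be
                # passed positionally anyway.
                break

        return ordered_input_names, tuple(model_args)

    class FuncContiguousArgs:
        def forward(self, input_ids, token_type_ids, attention_mask):
            pass

    class FuncNonContiguousArgs:
        def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):
            pass

    tokens = {"input_ids": [1, 2, 3, 4], "attention_mask": [0, 0, 0, 0], "token_type_ids": [1, 1, 1, 1]}
    names = ["input_ids", "attention_mask", "token_type_ids"]

    print(ensure_valid_input(FuncContiguousArgs(), tokens, names)[0])
    # ['input_ids', 'token_type_ids', 'attention_mask'] -- forward() order, not dict order
    print(ensure_valid_input(FuncNonContiguousArgs(), tokens, names)[0])
    # ['input_ids'] -- stops at the unprovided "some_other_args"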
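
The tokenizer argument threaded through convert exists for models that live
only on local disk, where no tokenizer can be resolved from the model path. A
usage sketch assembled from the new test case, assuming the
convert(framework, model, output, opset, tokenizer) call shape exercised
above; the model weights are randomly initialised and all paths are
throwaway:

    from tempfile import NamedTemporaryFile, TemporaryDirectory

    from transformers import BertConfig, BertModel, BertTokenizerFast
    from transformers.convert_graph_to_onnx import convert

    # Tiny vocabulary so the tokenizer can be built without any download.
    vocab = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]", "some", "other", "words"]
    with NamedTemporaryFile(mode="w+t") as vocab_file:
        vocab_file.write("\n".join(vocab))
        vocab_file.flush()
        tokenizer = BertTokenizerFast(vocab_file.name)  # vocab is read eagerly

    with TemporaryDirectory() as bert_save_dir, TemporaryDirectory() as onnx_dir:
        # A checkpoint that exists only locally -- the case the explicit
        # tokenizer argument is for.
        BertModel(BertConfig(vocab_size=len(vocab))).save_pretrained(bert_save_dir)
        convert("pt", bert_save_dir, onnx_dir + "/model.onnx", 11, tokenizer)

Keeping the ONNX file in its own empty directory mirrors the test and stays
clear of any output-folder checks convert performs before exporting.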