From c8d3fa0dfd191c0272f8de5027430e2fc789b22c Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 15 Feb 2021 13:55:10 +0100 Subject: [PATCH] Check TF ops for ONNX compliance (#10025) * Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut --- src/transformers/file_utils.py | 14 + .../models/gpt2/modeling_tf_gpt2.py | 19 +- src/transformers/testing_utils.py | 8 + tests/test_modeling_tf_albert.py | 1 + tests/test_modeling_tf_bart.py | 2 + tests/test_modeling_tf_bert.py | 2 + tests/test_modeling_tf_blenderbot.py | 1 + tests/test_modeling_tf_blenderbot_small.py | 1 + tests/test_modeling_tf_common.py | 64 ++++- tests/test_modeling_tf_convbert.py | 1 + tests/test_modeling_tf_ctrl.py | 1 + tests/test_modeling_tf_distilbert.py | 1 + tests/test_modeling_tf_dpr.py | 1 + tests/test_modeling_tf_electra.py | 1 + tests/test_modeling_tf_flaubert.py | 1 + tests/test_modeling_tf_funnel.py | 2 + tests/test_modeling_tf_gpt2.py | 2 + tests/test_modeling_tf_led.py | 2 + tests/test_modeling_tf_longformer.py | 2 + tests/test_modeling_tf_lxmert.py | 1 + tests/test_modeling_tf_marian.py | 1 + tests/test_modeling_tf_mbart.py | 1 + tests/test_modeling_tf_mobilebert.py | 1 + tests/test_modeling_tf_mpnet.py | 1 + tests/test_modeling_tf_openai.py | 1 + tests/test_modeling_tf_pegasus.py | 1 + tests/test_modeling_tf_roberta.py | 1 + tests/test_modeling_tf_t5.py | 2 + tests/test_modeling_tf_transfo_xl.py | 1 + tests/test_modeling_tf_xlm.py | 1 + tests/test_modeling_tf_xlnet.py | 1 + utils/check_tf_ops.py | 101 ++++++++ utils/tf_ops/onnx.json | 245 ++++++++++++++++++ 33 files changed, 468 insertions(+), 17 deletions(-) create mode 100644 utils/check_tf_ops.py create mode 100644 utils/tf_ops/onnx.json diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 53562e28f6d..3204270e8ac 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -151,6 +151,16 @@ except importlib_metadata.PackageNotFoundError: _faiss_available = False +_onnx_available = ( + importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None +) +try: + _onxx_version = importlib_metadata.version("onnx") + logger.debug(f"Successfully imported onnx version {_onxx_version}") +except importlib_metadata.PackageNotFoundError: + _onnx_available = False + + _scatter_available = importlib.util.find_spec("torch_scatter") is not None try: _scatter_version = importlib_metadata.version("torch_scatter") @@ -230,6 +240,10 @@ def is_tf_available(): return _tf_available +def is_onnx_available(): + return _onnx_available + + def is_flax_available(): return _flax_available diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 5657140c626..4ed3257466e 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -1030,16 +1030,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific ) - 1 ) - - def get_seq_element(sequence_position, input_batch): - return tf.strided_slice( - input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1] - ) - - result = tf.map_fn( - fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float" - ) - in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]]) + in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) else: sequence_lengths = -1 logger.warning( @@ -1049,16 +1040,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific loss = None if inputs["labels"] is not None: - if input_ids is not None: - batch_size, sequence_length = shape_list(inputs["input_ids"])[:2] - else: - batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2] assert ( - self.config.pad_token_id is not None or batch_size == 1 + self.config.pad_token_id is not None or logits_shape[0] == 1 ), "Cannot handle batch sizes > 1 if no padding token is defined." if not tf.is_tensor(sequence_lengths): - in_logits = logits[0:batch_size, sequence_lengths] + in_logits = logits[0 : logits_shape[0], sequence_lengths] loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels])) pooled_logits = in_logits if in_logits is not None else logits diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 29403bf21bb..0d16d8c07d0 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -28,6 +28,7 @@ from .file_utils import ( is_datasets_available, is_faiss_available, is_flax_available, + is_onnx_available, is_pandas_available, is_scatter_available, is_sentencepiece_available, @@ -160,6 +161,13 @@ def require_git_lfs(test_case): return test_case +def require_onnx(test_case): + if not is_onnx_available(): + return unittest.skip("test requires ONNX")(test_case) + else: + return test_case + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch. diff --git a/tests/test_modeling_tf_albert.py b/tests/test_modeling_tf_albert.py index 19a2c06d263..e043524b72d 100644 --- a/tests/test_modeling_tf_albert.py +++ b/tests/test_modeling_tf_albert.py @@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFAlbertModelTester(self) diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 5637e5addd9..04ecccec0fe 100644 --- a/tests/test_modeling_tf_bart.py +++ b/tests/test_modeling_tf_bart.py @@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_onnx = True + onnx_min_opset = 10 def setUp(self): self.model_tester = TFBartModelTester(self) diff --git a/tests/test_modeling_tf_bert.py b/tests/test_modeling_tf_bert.py index 6043f7926b0..8817ae2bc1c 100644 --- a/tests/test_modeling_tf_bert.py +++ b/tests/test_modeling_tf_bert.py @@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = True + onnx_min_opset = 10 # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py index f87de7f7d05..9b99f2aa18d 100644 --- a/tests/test_modeling_tf_blenderbot.py +++ b/tests/test_modeling_tf_blenderbot.py @@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_onnx = False def setUp(self): self.model_tester = TFBlenderbotModelTester(self) diff --git a/tests/test_modeling_tf_blenderbot_small.py b/tests/test_modeling_tf_blenderbot_small.py index 582dfc373f8..6fe2e18bea5 100644 --- a/tests/test_modeling_tf_blenderbot_small.py +++ b/tests/test_modeling_tf_blenderbot_small.py @@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_onnx = False def setUp(self): self.model_tester = TFBlenderbotSmallModelTester(self) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index a2524bd98b4..f1eb083cfea 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -16,6 +16,7 @@ import copy import inspect +import json import os import random import tempfile @@ -24,7 +25,7 @@ from importlib import import_module from typing import List, Tuple from transformers import is_tf_available -from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow +from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow if is_tf_available(): @@ -201,6 +202,67 @@ class TFModelTesterMixin: saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") self.assertTrue(os.path.exists(saved_model_dir)) + def test_onnx_compliancy(self): + if not self.test_onnx: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + INTERNAL_OPS = [ + "Assert", + "AssignVariableOp", + "EmptyTensorList", + "ReadVariableOp", + "ResourceGather", + "TruncatedNormal", + "VarHandleOp", + "VarIsInitializedOp", + ] + onnx_ops = [] + + with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f: + onnx_opsets = json.load(f)["opsets"] + + for i in range(1, self.onnx_min_opset + 1): + onnx_ops.extend(onnx_opsets[str(i)]) + + for model_class in self.all_model_classes: + model_op_names = set() + + with tf.Graph().as_default() as g: + model = model_class(config) + model(model.dummy_inputs) + + for op in g.get_operations(): + model_op_names.add(op.node_def.op) + + model_op_names = sorted(model_op_names) + incompatible_ops = [] + + for op in model_op_names: + if op not in onnx_ops and op not in INTERNAL_OPS: + incompatible_ops.append(op) + + self.assertEqual(len(incompatible_ops), 0, incompatible_ops) + + @require_onnx + @slow + def test_onnx_runtime_optimize(self): + if not self.test_onnx: + return + + import keras2onnx + import onnxruntime + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model(model.dummy_inputs) + + onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset) + + onnxruntime.InferenceSession(onnx_model.SerializeToString()) + @slow def test_saved_model_creation_extended(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_convbert.py b/tests/test_modeling_tf_convbert.py index 9ba3e75b225..69128a55ada 100644 --- a/tests/test_modeling_tf_convbert.py +++ b/tests/test_modeling_tf_convbert.py @@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase): ) test_pruning = False test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFConvBertModelTester(self) diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py index 09d6cf9fe65..781a7df8514 100644 --- a/tests/test_modeling_tf_ctrl.py +++ b/tests/test_modeling_tf_ctrl.py @@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else () all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else () test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFCTRLModelTester(self) diff --git a/tests/test_modeling_tf_distilbert.py b/tests/test_modeling_tf_distilbert.py index a10683f9a07..23a8f29d128 100644 --- a/tests/test_modeling_tf_distilbert.py +++ b/tests/test_modeling_tf_distilbert.py @@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): else None ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFDistilBertModelTester(self) diff --git a/tests/test_modeling_tf_dpr.py b/tests/test_modeling_tf_dpr.py index ed37ed3ed23..39e82fd3ab5 100644 --- a/tests/test_modeling_tf_dpr.py +++ b/tests/test_modeling_tf_dpr.py @@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase): test_missing_keys = False test_pruning = False test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFDPRModelTester(self) diff --git a/tests/test_modeling_tf_electra.py b/tests/test_modeling_tf_electra.py index 6cdd0e0c6d8..0f627202361 100644 --- a/tests/test_modeling_tf_electra.py +++ b/tests/test_modeling_tf_electra.py @@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFElectraModelTester(self) diff --git a/tests/test_modeling_tf_flaubert.py b/tests/test_modeling_tf_flaubert.py index 53a899e0bf6..24802cfbabb 100644 --- a/tests/test_modeling_tf_flaubert.py +++ b/tests/test_modeling_tf_flaubert.py @@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): (TFFlaubertWithLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFFlaubertModelTester(self) diff --git a/tests/test_modeling_tf_funnel.py b/tests/test_modeling_tf_funnel.py index e0d50aeb14f..1b8572deac7 100644 --- a/tests/test_modeling_tf_funnel.py +++ b/tests/test_modeling_tf_funnel.py @@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFFunnelModelTester(self) @@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase): (TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else () ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFFunnelModelTester(self, base=True) diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 4d9a12384bc..48ca4eef7f7 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): ) all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else () test_head_masking = False + test_onnx = True + onnx_min_opset = 10 def setUp(self): self.model_tester = TFGPT2ModelTester(self) diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index 9b20a8136b9..2bfc2f78b86 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFLEDModelTester(self) diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index 951bca5fca3..bb8d94dfdf2 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFLongformerModelTester(self) diff --git a/tests/test_modeling_tf_lxmert.py b/tests/test_modeling_tf_lxmert.py index 31afc66ecb2..f2555acaf42 100644 --- a/tests/test_modeling_tf_lxmert.py +++ b/tests/test_modeling_tf_lxmert.py @@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else () test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFLxmertModelTester(self) diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index 292f4893138..49b2d366cc4 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_onnx = False def setUp(self): self.model_tester = TFMarianModelTester(self) diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py index eb0cb553cae..c912c009408 100644 --- a/tests/test_modeling_tf_mbart.py +++ b/tests/test_modeling_tf_mbart.py @@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_onnx = False def setUp(self): self.model_tester = TFMBartModelTester(self) diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py index d98e2fd27e6..02e86f1f8de 100644 --- a/tests/test_modeling_tf_mobilebert.py +++ b/tests/test_modeling_tf_mobilebert.py @@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = False class TFMobileBertModelTester(object): def __init__( diff --git a/tests/test_modeling_tf_mpnet.py b/tests/test_modeling_tf_mpnet.py index 160283350be..d67d68f5d30 100644 --- a/tests/test_modeling_tf_mpnet.py +++ b/tests/test_modeling_tf_mpnet.py @@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFMPNetModelTester(self) diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py index 7da10235fb0..87e10584347 100644 --- a/tests/test_modeling_tf_openai.py +++ b/tests/test_modeling_tf_openai.py @@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): (TFOpenAIGPTLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFOpenAIGPTModelTester(self) diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py index a469aff7fb2..23a8f6aa6cf 100644 --- a/tests/test_modeling_tf_pegasus.py +++ b/tests/test_modeling_tf_pegasus.py @@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_onnx = False def setUp(self): self.model_tester = TFPegasusModelTester(self) diff --git a/tests/test_modeling_tf_roberta.py b/tests/test_modeling_tf_roberta.py index 66cb128c8ac..d40652efc92 100644 --- a/tests/test_modeling_tf_roberta.py +++ b/tests/test_modeling_tf_roberta.py @@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): else () ) test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFRobertaModelTester(self) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 2d0638f0e41..b611c8553c9 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else () all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else () test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFT5ModelTester(self) @@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): is_encoder_decoder = False all_model_classes = (TFT5EncoderModel,) if is_tf_available() else () test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFT5EncoderOnlyModelTester(self) diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index a903831208f..7232091a0eb 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): # TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented test_resize_embeddings = False test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFTransfoXLModelTester(self) diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index e3eb1bdbc1a..2860c992433 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -294,6 +294,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): (TFXLMWithLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFXLMModelTester(self) diff --git a/tests/test_modeling_tf_xlnet.py b/tests/test_modeling_tf_xlnet.py index f9ea93c21c8..51fba4575fe 100644 --- a/tests/test_modeling_tf_xlnet.py +++ b/tests/test_modeling_tf_xlnet.py @@ -348,6 +348,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): (TFXLNetLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable test_head_masking = False + test_onnx = False def setUp(self): self.model_tester = TFXLNetModelTester(self) diff --git a/utils/check_tf_ops.py b/utils/check_tf_ops.py new file mode 100644 index 00000000000..f6c2b8bae4e --- /dev/null +++ b/utils/check_tf_ops.py @@ -0,0 +1,101 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +from tensorflow.core.protobuf.saved_model_pb2 import SavedModel + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_copies.py +REPO_PATH = "." + +# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model) +INTERNAL_OPS = [ + "Assert", + "AssignVariableOp", + "EmptyTensorList", + "MergeV2Checkpoints", + "ReadVariableOp", + "ResourceGather", + "RestoreV2", + "SaveV2", + "ShardedFilename", + "StatefulPartitionedCall", + "StaticRegexFullMatch", + "VarHandleOp", +] + + +def onnx_compliancy(saved_model_path, strict, opset): + saved_model = SavedModel() + onnx_ops = [] + + with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f: + onnx_opsets = json.load(f)["opsets"] + + for i in range(1, opset + 1): + onnx_ops.extend(onnx_opsets[str(i)]) + + with open(saved_model_path, "rb") as f: + saved_model.ParseFromString(f.read()) + + model_op_names = set() + + # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs) + for meta_graph in saved_model.meta_graphs: + # Add operations in the graph definition + model_op_names.update(node.op for node in meta_graph.graph_def.node) + + # Go through the functions in the graph definition + for func in meta_graph.graph_def.library.function: + # Add operations in each function + model_op_names.update(node.op for node in func.node_def) + + # Convert to list, sorted if you want + model_op_names = sorted(model_op_names) + incompatible_ops = [] + + for op in model_op_names: + if op not in onnx_ops and op not in INTERNAL_OPS: + incompatible_ops.append(op) + + if strict and len(incompatible_ops) > 0: + raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops) + elif len(incompatible_ops) > 0: + print(f"Found the following incompatible ops for the opset {opset}:") + print(*incompatible_ops, sep="\n") + else: + print(f"The saved model {saved_model_path} can properly be converted with ONNX.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).") + parser.add_argument( + "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested." + ) + parser.add_argument( + "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model." + ) + parser.add_argument( + "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)" + ) + args = parser.parse_args() + + if args.framework == "onnx": + onnx_compliancy(args.saved_model_path, args.strict, args.opset) diff --git a/utils/tf_ops/onnx.json b/utils/tf_ops/onnx.json new file mode 100644 index 00000000000..a468145d66e --- /dev/null +++ b/utils/tf_ops/onnx.json @@ -0,0 +1,245 @@ +{ + "opsets": { + "1": [ + "Abs", + "Add", + "AddV2", + "ArgMax", + "ArgMin", + "AvgPool", + "AvgPool3D", + "BatchMatMul", + "BatchMatMulV2", + "BatchToSpaceND", + "BiasAdd", + "BiasAddV1", + "Cast", + "Ceil", + "CheckNumerics", + "ComplexAbs", + "Concat", + "ConcatV2", + "Const", + "ConstV2", + "Conv1D", + "Conv2D", + "Conv2DBackpropInput", + "Conv3D", + "Conv3DBackpropInputV2", + "DepthToSpace", + "DepthwiseConv2d", + "DepthwiseConv2dNative", + "Div", + "Dropout", + "Elu", + "Equal", + "Erf", + "Exp", + "ExpandDims", + "Flatten", + "Floor", + "Gather", + "GatherNd", + "GatherV2", + "Greater", + "Identity", + "IdentityN", + "If", + "LRN", + "LSTMBlockCell", + "LeakyRelu", + "Less", + "Log", + "LogSoftmax", + "LogicalAnd", + "LogicalNot", + "LogicalOr", + "LookupTableSizeV2", + "MatMul", + "Max", + "MaxPool", + "MaxPool3D", + "MaxPoolV2", + "Maximum", + "Mean", + "Min", + "Minimum", + "MirrorPad", + "Mul", + "Neg", + "NoOp", + "NotEqual", + "OneHot", + "Pack", + "Pad", + "PadV2", + "Placeholder", + "PlaceholderV2", + "PlaceholderWithDefault", + "Pow", + "Prod", + "RFFT", + "RandomNormal", + "RandomNormalLike", + "RandomUniform", + "RandomUniformLike", + "RealDiv", + "Reciprocal", + "Relu", + "Relu6", + "Reshape", + "Rsqrt", + "Selu", + "Shape", + "Sigmoid", + "Sign", + "Size", + "Slice", + "Softmax", + "Softplus", + "Softsign", + "SpaceToBatchND", + "SpaceToDepth", + "Split", + "SplitV", + "Sqrt", + "Square", + "SquaredDifference", + "Squeeze", + "StatelessIf", + "StopGradient", + "StridedSlice", + "StringJoin", + "Sub", + "Sum", + "Tanh", + "Tile", + "TopKV2", + "Transpose", + "TruncateDiv", + "Unpack", + "ZerosLike" + ], + "2": [], + "3": [], + "4": [], + "5": [], + "6": [ + "AddN", + "All", + "Any", + "FloorDiv", + "FusedBatchNorm", + "FusedBatchNormV2", + "FusedBatchNormV3" + ], + "7": [ + "Acos", + "Asin", + "Atan", + "Cos", + "Fill", + "FloorMod", + "GreaterEqual", + "LessEqual", + "Loop", + "MatrixBandPart", + "Multinomial", + "Range", + "ResizeBilinear", + "ResizeNearestNeighbor", + "Scan", + "Select", + "SelectV2", + "Sin", + "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + "StatelessWhile", + "Tan", + "TensorListFromTensor", + "TensorListGetItem", + "TensorListLength", + "TensorListReserve", + "TensorListResize", + "TensorListSetItem", + "TensorListStack", + "While" + ], + "8": [ + "BroadcastTo", + "ClipByValue", + "FIFOQueueV2", + "HashTableV2", + "IteratorGetNext", + "IteratorV2", + "LookupTableFindV2", + "MaxPoolWithArgmax", + "QueueDequeueManyV2", + "QueueDequeueUpToV2", + "QueueDequeueV2", + "ReverseSequence" + ], + "9": [ + "SegmentMax", + "SegmentMean", + "SegmentMin", + "SegmentProd", + "SegmentSum", + "Sinh", + "SparseSegmentMean", + "SparseSegmentMeanWithNumSegments", + "SparseSegmentSqrtN", + "SparseSegmentSqrtNWithNumSegments", + "SparseSegmentSum", + "SparseSegmentSumWithNumSegments", + "UnsortedSegmentMax", + "UnsortedSegmentMin", + "UnsortedSegmentProd", + "UnsortedSegmentSum", + "Where" + ], + "10": [ + "CropAndResize", + "CudnnRNN", + "DynamicStitch", + "FakeQuantWithMinMaxArgs", + "IsFinite", + "IsInf", + "NonMaxSuppressionV2", + "NonMaxSuppressionV3", + "NonMaxSuppressionV4", + "NonMaxSuppressionV5", + "ParallelDynamicStitch", + "ReverseV2", + "Roll" + ], + "11": [ + "Bincount", + "Cumsum", + "InvertPermutation", + "LeftShift", + "MatrixDeterminant", + "MatrixDiagPart", + "MatrixDiagPartV2", + "MatrixDiagPartV3", + "RaggedRange", + "RightShift", + "Round", + "ScatterNd", + "SparseFillEmptyRows", + "SparseReshape", + "SparseToDense", + "TensorScatterUpdate", + "Unique" + ], + "12": [ + "Einsum", + "MatrixDiag", + "MatrixDiagV2", + "MatrixDiagV3", + "MatrixSetDiagV3", + "SquaredDistance" + ], + "13": [] + } +} \ No newline at end of file